-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresponse_generator.py
More file actions
477 lines (418 loc) · 20 KB
/
response_generator.py
File metadata and controls
477 lines (418 loc) · 20 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
"""
Empathetic Response Generator for Crisis Intervention
Generates trauma-informed, supportive responses using Gemma model
"""
import json
import logging
import random
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
logger = logging.getLogger(__name__)
class CrisisResponseGenerator:
    """
    Generates empathetic, trauma-informed responses for crisis situations.

    A curated template provides a safe baseline response per crisis type;
    a Gemma language model then personalizes it using the user's message
    and any safety analysis supplied by the caller. Every result carries
    relevant crisis resources and, for immediate risk, a safety plan.
    """

    def __init__(self, model_name: str = "google/gemma-2b-it"):
        """
        Initialize the response generator.

        Args:
            model_name: Hugging Face model identifier for Gemma
        """
        self.model_name = model_name
        # Prefer GPU when available; _load_model uses fp16 on CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model and tokenizer
        self._load_model()

        # Response templates for different crisis types. Each type has an
        # "immediate" tier (high risk / high confidence) and a "supportive"
        # tier; generate_response picks one at random for variety.
        self.response_templates = {
            "self_harm": {
                "immediate": [
                    "I can hear that you're in a lot of pain right now. You don't have to go through this alone.",
                    "I'm really concerned about you. Your life has value, even when it doesn't feel that way.",
                    "I want you to know that what you're feeling right now is valid, and there are people who care about you."
                ],
                "supportive": [
                    "It takes courage to reach out when you're struggling. I'm glad you did.",
                    "You're not alone in this. Many people have felt this way and found ways to cope.",
                    "Your feelings are important, and so are you."
                ]
            },
            "suicide": {
                "immediate": [
                    "I'm very concerned about you right now. Your life matters, and I want to help you stay safe.",
                    "I can hear how much pain you're in. Please know that you don't have to face this alone.",
                    "I care about you, and I want to make sure you're safe. Can we talk about what's happening?"
                ],
                "supportive": [
                    "Thank you for sharing this with me. It takes incredible strength to be so honest.",
                    "I'm here with you. You don't have to carry this burden alone.",
                    "Your life has meaning, even when it's hard to see right now."
                ]
            },
            "violence": {
                "immediate": [
                    "I'm concerned about your safety. Violence is never the answer, and there are better ways to handle this.",
                    "I can hear that you're very angry right now. Let's talk about what's really bothering you.",
                    "I want to help you find a safer way to express these feelings."
                ],
                "supportive": [
                    "It's okay to feel angry, but we need to find safe ways to express it.",
                    "I'm here to listen and help you work through these feelings.",
                    "There are people who can help you resolve this situation safely."
                ]
            },
            "abuse": {
                "immediate": [
                    "I'm so sorry this is happening to you. You don't deserve to be treated this way.",
                    "Your safety is the most important thing right now. You're not alone.",
                    "I believe you, and I want to help you get to safety."
                ],
                "supportive": [
                    "It took courage to share this with me. You're not to blame for what happened.",
                    "You deserve to be treated with respect and kindness.",
                    "I'm here to support you in whatever way you need."
                ]
            },
            "overdose": {
                "immediate": [
                    "I'm very concerned about your safety right now. Please don't take any more.",
                    "Your life is valuable, and I want to help you stay safe.",
                    "If you've already taken something, please call emergency services immediately."
                ],
                "supportive": [
                    "I'm glad you're reaching out. You don't have to face this alone.",
                    "There are people who care about you and want to help you through this.",
                    "Recovery is possible, and you deserve support."
                ]
            }
        }

        # Crisis resource database
        self.crisis_resources = self._load_crisis_resources()

    def _load_model(self):
        """
        Load the Gemma model and tokenizer for response generation.

        Uses fp16 + device_map="auto" on CUDA, fp32 on CPU.

        Raises:
            Exception: re-raises any load failure after logging it.
        """
        try:
            logger.info(f"Loading response model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True
            )
            # device_map handles placement on CUDA; only move explicitly on CPU.
            if self.device == "cpu":
                self.model = self.model.to(self.device)
            logger.info("Response model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading response model: {e}")
            raise

    def _load_crisis_resources(self) -> Dict[str, List[Dict]]:
        """Load comprehensive crisis resources and helplines, keyed by crisis type."""
        return {
            "general": [
                {
                    "name": "National Suicide Prevention Lifeline",
                    "number": "988",
                    "text": "Text HOME to 741741",
                    "description": "24/7 crisis support for suicide prevention",
                    "available": "24/7"
                },
                {
                    "name": "Crisis Text Line",
                    "number": "Text HOME to 741741",
                    "description": "Free, 24/7 crisis support via text",
                    "available": "24/7"
                }
            ],
            "self_harm": [
                {
                    "name": "Self-Injury Outreach & Support",
                    "website": "sioutreach.org",
                    "description": "Resources and support for self-injury recovery"
                },
                {
                    "name": "To Write Love on Her Arms",
                    "website": "twloha.com",
                    "description": "Hope and help for people struggling with depression, addiction, self-injury, and suicide"
                }
            ],
            "suicide": [
                {
                    "name": "National Suicide Prevention Lifeline",
                    "number": "988",
                    "text": "Text HOME to 741741",
                    "description": "24/7 crisis support for suicide prevention",
                    "available": "24/7"
                },
                {
                    "name": "American Foundation for Suicide Prevention",
                    "website": "afsp.org",
                    "description": "Resources, support groups, and prevention programs"
                }
            ],
            "violence": [
                {
                    "name": "National Domestic Violence Hotline",
                    "number": "1-800-799-7233",
                    "text": "Text START to 88788",
                    "description": "24/7 support for domestic violence",
                    "available": "24/7"
                },
                {
                    "name": "National Sexual Assault Hotline",
                    "number": "1-800-656-4673",
                    "description": "24/7 support for sexual assault survivors",
                    "available": "24/7"
                }
            ],
            "abuse": [
                {
                    "name": "National Domestic Violence Hotline",
                    "number": "1-800-799-7233",
                    "text": "Text START to 88788",
                    "description": "24/7 support for domestic violence",
                    "available": "24/7"
                },
                {
                    "name": "Childhelp National Child Abuse Hotline",
                    "number": "1-800-4-A-CHILD (1-800-422-4453)",
                    "description": "24/7 support for child abuse",
                    "available": "24/7"
                }
            ],
            "overdose": [
                {
                    "name": "SAMHSA National Helpline",
                    "number": "1-800-662-4357",
                    "description": "24/7 treatment referral and information service",
                    "available": "24/7"
                },
                {
                    "name": "National Poison Control Center",
                    "number": "1-800-222-1222",
                    "description": "24/7 poison emergency support",
                    "available": "24/7"
                }
            ]
        }

    def generate_response(self,
                          crisis_type: str,
                          user_message: str,
                          confidence: float,
                          immediate_risk: bool = False,
                          safety_analysis: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Generate an empathetic response for a crisis situation using Gemma safety analysis.

        Args:
            crisis_type: Type of crisis detected
            user_message: User's original message
            confidence: Confidence level of crisis detection
            immediate_risk: Whether there's immediate risk
            safety_analysis: Detailed safety analysis from Gemma

        Returns:
            Dictionary containing response text, resources, optional safety
            plan, and the echoed-back risk/confidence/analysis metadata.
            Falls back to a static response if generation fails.
        """
        try:
            # Determine response urgency based on safety analysis
            response_type = self._determine_response_type(confidence, immediate_risk, safety_analysis)

            if crisis_type in self.response_templates:
                base_response = random.choice(self.response_templates[crisis_type][response_type])
            else:
                # Unknown crisis type: use a generic supportive baseline.
                base_response = "I'm here to listen and support you. You don't have to face this alone."

            # Generate personalized response using Gemma with safety context
            personalized_response = self._generate_safety_aware_response(
                user_message, base_response, crisis_type, immediate_risk, safety_analysis
            )

            # Get relevant resources based on safety analysis
            resources = self._get_relevant_resources(crisis_type, immediate_risk)

            # Create safety plan if needed
            safety_plan = self._create_safety_plan(crisis_type, immediate_risk) if immediate_risk else None

            return {
                "response": personalized_response,
                "resources": resources,
                "safety_plan": safety_plan,
                "immediate_risk": immediate_risk,
                "crisis_type": crisis_type,
                "confidence": confidence,
                "safety_analysis": safety_analysis
            }
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self._get_fallback_response(crisis_type, immediate_risk)

    def _determine_response_type(self, confidence: float, immediate_risk: bool,
                                 safety_analysis: Optional[Dict] = None) -> str:
        """
        Determine response tier ("immediate" or "supportive") from the risk signals.

        Args:
            confidence: Crisis-detection confidence in [0, 1]
            immediate_risk: Whether there's immediate risk
            safety_analysis: Optional per-category dicts with "detected"/"severity"

        Returns:
            "immediate" or "supportive" — always a valid template tier.
        """
        if immediate_risk or confidence > 0.8:
            return "immediate"

        if safety_analysis:
            # Any detected category with high severity escalates the tier.
            for category, data in safety_analysis.items():
                if data.get("severity") == "high" and data.get("detected", False):
                    return "immediate"

        # BUG FIX: the original fell off the end (returning None) whenever
        # safety_analysis was provided but contained no high-severity
        # detection, which made the template lookup raise KeyError. All
        # remaining cases are supportive.
        return "supportive"

    def _generate_safety_aware_response(self,
                                        user_message: str,
                                        base_response: str,
                                        crisis_type: str,
                                        immediate_risk: bool,
                                        safety_analysis: Optional[Dict] = None) -> str:
        """
        Use Gemma model to generate a safety-aware personalized response.

        Args:
            user_message: User's original message
            base_response: Base response template
            crisis_type: Type of crisis
            immediate_risk: Whether there's immediate risk
            safety_analysis: Detailed safety analysis from Gemma

        Returns:
            Personalized response text, or base_response if generation fails
            or the model output looks too short / like a refusal.
        """
        try:
            # Create a safety-aware prompt for response generation
            risk_level = "immediate crisis" if immediate_risk else "crisis situation"

            # Build safety context for the prompt from detected categories only.
            safety_context = ""
            if safety_analysis:
                detected_categories = [cat for cat, data in safety_analysis.items()
                                       if data.get("detected", False)]
                if detected_categories:
                    safety_context = f"\nSafety Analysis: Detected {', '.join(detected_categories)} with severity levels: " + \
                                     ", ".join([f"{cat}: {data.get('severity', 'unknown')}"
                                                for cat, data in safety_analysis.items()
                                                if data.get("detected", False)])

            prompt = f"""<start_of_turn>user
You are a compassionate crisis counselor with advanced safety training. Generate a supportive, empathetic response for someone in a {risk_level} related to {crisis_type}.
Safety Guidelines:
- Be warm, non-judgmental, and validating
- Acknowledge their pain without minimizing it
- Offer hope and support while maintaining safety awareness
- Keep response under 150 words
- Use "I" statements to show care
- Avoid giving medical advice
- Prioritize immediate safety if high risk detected
- Provide appropriate crisis resources
User's message: "{user_message}"
Base response: "{base_response}"{safety_context}
Generate a personalized, empathetic, and safety-aware response:
<end_of_turn>
<start_of_turn>model"""

            # Tokenize and generate
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.6,  # Slightly lower temperature for more consistent safety-aware responses
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract the model's response.
            # NOTE(review): skip_special_tokens=True may strip the
            # "<start_of_turn>" marker from the decoded text, in which case
            # this split never matches and the whole decode (prompt included)
            # is returned — verify against the tokenizer's special-token set.
            if "<start_of_turn>model" in response:
                model_response = response.split("<start_of_turn>model")[-1].strip()
            else:
                model_response = response

            # Clean up the response
            model_response = model_response.replace("<end_of_turn>", "").strip()

            # Fallback to base response if model response is too short or inappropriate
            if len(model_response) < 20 or any(word in model_response.lower() for word in ["sorry", "can't help", "unable"]):
                return base_response

            return model_response
        except Exception as e:
            logger.error(f"Error in safety-aware response generation: {e}")
            return base_response

    def _get_relevant_resources(self, crisis_type: str, immediate_risk: bool) -> List[Dict]:
        """
        Get relevant crisis resources based on crisis type.

        Args:
            crisis_type: Type of crisis
            immediate_risk: Whether there's immediate risk

        Returns:
            Up to 5 unique resources; 24/7-only when immediate_risk is True.
        """
        resources = []

        # Always include general crisis resources
        resources.extend(self.crisis_resources["general"])

        # Add specific resources for the crisis type
        if crisis_type in self.crisis_resources:
            resources.extend(self.crisis_resources[crisis_type])

        # BUG FIX: deduplicate by name, preserving order — the general list
        # and type-specific lists overlap (e.g. the 988 Lifeline appears in
        # both "general" and "suicide"), and duplicates wasted slots in the
        # 5-item cap below.
        seen_names = set()
        unique_resources = []
        for resource in resources:
            name = resource.get("name")
            if name not in seen_names:
                seen_names.add(name)
                unique_resources.append(resource)
        resources = unique_resources

        # Prioritize immediate resources if there's immediate risk
        if immediate_risk:
            resources = [r for r in resources if r.get("available") == "24/7"]

        return resources[:5]  # Limit to 5 resources

    def _create_safety_plan(self, crisis_type: str, immediate_risk: bool) -> Optional[Dict[str, Any]]:
        """
        Create a safety plan for immediate risk situations.

        Args:
            crisis_type: Type of crisis (currently unused; the plan is generic)
            immediate_risk: Whether there's immediate risk

        Returns:
            Safety plan dictionary, or None when there is no immediate risk.
        """
        if not immediate_risk:
            return None

        safety_plan = {
            "immediate_actions": [
                "Call 911 or go to the nearest emergency room if you're in immediate danger",
                "Remove any means of self-harm from your immediate environment",
                "Stay with a trusted person or in a public place",
                "Call a crisis hotline: 988 (Suicide & Crisis Lifeline)"
            ],
            "coping_strategies": [
                "Practice deep breathing exercises",
                "Use grounding techniques (5-4-3-2-1 method)",
                "Reach out to a trusted friend or family member",
                "Engage in a calming activity you enjoy"
            ],
            "warning_signs": [
                "Feeling hopeless or worthless",
                "Having thoughts of self-harm or suicide",
                "Feeling isolated or alone",
                "Changes in sleep or appetite"
            ],
            "emergency_contacts": [
                "National Suicide Prevention Lifeline: 988",
                "Crisis Text Line: Text HOME to 741741",
                "Emergency Services: 911"
            ]
        }
        return safety_plan

    def _get_fallback_response(self, crisis_type: str, immediate_risk: bool) -> Dict[str, Any]:
        """
        Get a fallback response if generation fails.

        Args:
            crisis_type: Type of crisis
            immediate_risk: Whether there's immediate risk

        Returns:
            Fallback response dictionary with the same keys as the success
            path of generate_response.
        """
        if immediate_risk:
            response = "I'm very concerned about your safety right now. Please call 988 (Suicide & Crisis Lifeline) or 911 immediately. You don't have to face this alone."
        else:
            response = "I'm here to listen and support you. You don't have to face this alone. Please consider reaching out to a crisis hotline or mental health professional."

        return {
            "response": response,
            "resources": self.crisis_resources["general"][:3],
            "safety_plan": self._create_safety_plan(crisis_type, immediate_risk) if immediate_risk else None,
            "immediate_risk": immediate_risk,
            "crisis_type": crisis_type,
            "confidence": 0.5,
            # CONSISTENCY FIX: the success path of generate_response always
            # includes this key; callers can rely on a uniform schema.
            "safety_analysis": None
        }
# Example usage: exercise the generator with a high-risk sample message
# and pretty-print the full result payload.
if __name__ == "__main__":
    generator = CrisisResponseGenerator()

    result = generator.generate_response(
        crisis_type="suicide",
        user_message="I don't want to live anymore",
        confidence=0.8,
        immediate_risk=True,
    )

    print("Generated Response:")
    print(json.dumps(result, indent=2))