-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
82 lines (66 loc) · 2.28 KB
/
app.py
File metadata and controls
82 lines (66 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import hmac
import io
import os
import threading

from flask import Flask, request, send_file, jsonify
from rembg import remove
import tensorflow as tf

from shared import TYPE_DETECTION_OUT_OUTPUT_H5_PATH


app = Flask(__name__)

# Expected value of the X-API-Key request header. When the variable is unset
# (None) or empty, require_api_key lets every request through.
API_KEY = os.environ.get("API_KEY")


@app.before_request
def require_api_key():
    """Reject requests that do not carry the configured API key.

    Runs before every request. If API_KEY is unset or empty, the check is
    skipped and every request is allowed through.

    Returns:
        None to let the request continue, or a JSON ``{"error": "Forbidden"}``
        response with status 403 when the ``X-API-Key`` header is missing or
        does not match API_KEY.
    """
    if not API_KEY:
        # NOTE(review): authentication is fail-open when API_KEY is not
        # configured -- confirm this is intended for every deployment.
        return None
    supplied = request.headers.get("X-API-Key", "")
    # compare_digest avoids content-based short-circuiting, so response timing
    # does not reveal how many leading characters of a guess were correct.
    # Comparing bytes also avoids the TypeError it raises for non-ASCII str.
    if not hmac.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8")):
        return jsonify({"error": "Forbidden"}), 403
    return None


@app.route('/remove', methods=['POST'])
def remove_background():
    """Remove the background from an uploaded image.

    POST /remove
    form-data: file -> image to process

    Returns:
        200 with the bytes returned by rembg's ``remove``, served as image/png.
        400 with a JSON ``error`` if no file, or an empty file, was uploaded.
        500 with a JSON ``error`` message if rembg fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'no file provided'}), 400
    data = request.files['file'].read()
    # Browsers submit an empty part when the file input is left blank; reject
    # it as a client error instead of letting rembg fail with a 500.
    if not data:
        return jsonify({'error': 'empty file provided'}), 400
    try:
        out = remove(data)
    except Exception as e:
        # Record the traceback server-side; the client still gets the message.
        app.logger.exception("background removal failed")
        return jsonify({'error': str(e)}), 500
    return send_file(io.BytesIO(out), mimetype='image/png')


# Lazily-loaded classifier shared by all requests (see load_classification_model).
_model = None
# Serializes the first load: requests may be served concurrently on several
# threads (Flask's development server is threaded by default), and without
# the lock two first requests could each load the model.
_model_lock = threading.Lock()


def load_classification_model():
    """Load the Keras classifier once and cache it in the module-level ``_model``.

    Cheap to call on every request: after the first successful load it only
    returns the cached model.

    Returns:
        The cached Keras model. Existing callers that ignore the return value
        and read ``_model`` directly keep working.

    Raises:
        FileNotFoundError: if TYPE_DETECTION_OUT_OUTPUT_H5_PATH does not exist.
        Any exception raised by ``tf.keras.models.load_model``. ``_model``
        stays None in that case, so a later call retries the load.
    """
    global _model
    if _model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting to acquire it.
            if _model is None:
                if not os.path.exists(TYPE_DETECTION_OUT_OUTPUT_H5_PATH):
                    raise FileNotFoundError(f"Model file not found at {TYPE_DETECTION_OUT_OUTPUT_H5_PATH}. Create it using train.py")
                _model = tf.keras.models.load_model(TYPE_DETECTION_OUT_OUTPUT_H5_PATH)
    return _model


@app.route('/classify', methods=['POST'])
def classify_image():
    """Classify an uploaded image with the cached Keras model.

    POST /classify
    form-data: file -> image to classify

    The image is decoded as 3-channel RGB, resized to 224x224, scaled by
    1/255 and passed to the model as a batch of one.

    Returns:
        200 with a bare JSON integer (not an object): the argmax index of the
        model's output for the image.
        400 with a JSON ``error`` if no file, an empty file, or undecodable
        image data was uploaded.
        500 with a JSON ``error`` if the model cannot be loaded or inference
        fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'no file provided'}), 400
    data = request.files['file'].read()
    if not data:
        return jsonify({'error': 'empty file provided'}), 400

    try:
        load_classification_model()
    except Exception as e:
        app.logger.exception("model load failed")
        return jsonify({'error': f"model load error: {e}"}), 500

    try:
        # expand_animations=False decodes an animated GIF to its first frame
        # as a 3-D tensor; by default it is 4-D (frames first) and would break
        # the single-image batching below.
        img = tf.io.decode_image(data, channels=3, expand_animations=False)
    except tf.errors.InvalidArgumentError as e:
        # Unrecognised or corrupt image bytes are a client error.
        return jsonify({'error': f"invalid image: {e}"}), 400

    try:
        img = tf.image.resize(img, [224, 224])
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.expand_dims(img, axis=0)  # batch dim
        # verbose=0: don't print a progress bar to stdout on every request.
        preds = _model.predict(img, verbose=0)
        label_index = int(tf.argmax(preds[0]).numpy())
    except Exception as e:
        app.logger.exception("classification failed")
        return jsonify({'error': str(e)}), 500
    # Deliberately no tf.keras.backend.clear_session() here: the model is
    # loaded once and reused, so resetting Keras' global state after each
    # request is unnecessary and not safe while other requests may be
    # using the same model.
    return jsonify(label_index)


if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary Python, so never enable it by default while binding to all
    # interfaces. Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_enabled = os.getenv("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    app.run(
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "5000")),
        debug=debug_enabled,
    )