Browse Source

fix error

liuyuqi-dellpc 1 year ago
parent
commit
240bcd5658

+ 0 - 10
README.md

@@ -32,16 +32,6 @@ python app.py
 使用模型进行预测,同时使用`pdx.seg.visualize`将结果可视化,可视化结果将保存到`./output/deeplab`下,其中`weight`代表原图的权重,即mask可视化结果与原图权重因子。
 
 
-```python
-import paddlex as pdx
-model = pdx.deploy.Predictor('inference_model')
-image_name = 'optic_disc_seg/JPEGImages/H0005.jpg'
-result = model.predict(image_name)
-pdx.seg.visualize(image_name, result, weight=0.4, save_dir='./output/deeplab')
-```
-
-    2021-01-23 08:16:45 [INFO]	The visualized result is saved as ./output/deeplab/visualize_H0005.jpg
-
 
 
 ```python

+ 37 - 13
apps/__init__.py

@@ -1,8 +1,17 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@Contact :   liuyuqi.gov@msn.cn
+@Time    :   2023/11/09 14:27:17
+@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
+@Desc    :   create flask app
+'''
+
 import os
 from flask import Flask
 from apps.config import config
 from apps.views import init_blueprints
-
+from apps.extensions import init_plugins
 
 def create_app(config_name="default") -> Flask:
     ''' create app '''
@@ -17,6 +26,9 @@ def create_app(config_name="default") -> Flask:
 
     init_dir()
     init_blueprints(app)
+    init_plugins(app)
+    init_hook(app)
+
     return app
 
 def init_dir():
@@ -29,15 +41,27 @@ def init_dir():
         if not os.path.exists(ff):
             os.makedirs(ff)
 
-# 添加header解决跨域
-# @app.after_request
-# def after_request(response):
-#     response.headers['Access-Control-Allow-Origin'] = '*'
-#     response.headers['Access-Control-Allow-Credentials'] = 'true'
-#     response.headers['Access-Control-Allow-Methods'] = 'POST'
-#     response.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
-#     return response
-
-    # with app.app_context():
-    #     current_app.model = deploy.Predictor(
-    #         './core/net/inference_model', use_gpu=True)
+def init_hook(app: Flask):
+    ''' init hook '''
+
+    @app.after_request
+    def after_request(response):
+        ''' resolve cross-domain '''
+
+        # response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
+        # response.headers["Expires"] = 0
+        # response.headers["Pragma"] = "no-cache"
+        
+        # response.set_cookie('remember_token', '', expires=0)
+        # response.headers.add('X-Version', app.config['CURRENT_VERSION'])
+        # response.headers.add('X-Env', app.config['DEPLOY_ENV'])
+        
+        response.headers['Access-Control-Allow-Origin'] = '*'
+        response.headers['Access-Control-Allow-Credentials'] = 'true'
+        response.headers['Access-Control-Allow-Methods'] = 'POST'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
+        return response
+
+    @app.before_request
+    def before_request():
+        pass

+ 12 - 3
apps/config.py

@@ -1,10 +1,19 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@Contact :   liuyuqi.gov@msn.cn
+@Time    :   2023/11/09 14:21:32
+@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
+@Desc    :   config file
+'''
+
 import os
 import random
 import string
 import logging
 from datetime import timedelta
 from dotenv import load_dotenv
-
+import paddlex as pdx
 
 if os.path.exists('.env'):
     load_dotenv('.env', verbose=True)
@@ -19,7 +28,7 @@ class BaseConfig:
         SECRET_KEY = ''.join(random.choice(string.ascii_lowercase)
                              for i in range(32))
 # app.secret_key = 'secret!'
-    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+    UPLOAD_FOLDER = "tmp/upload"
 
 # werkzeug_logger = rel_log.getLogger('werkzeug')
 # werkzeug_logger.setLevel(rel_log.ERROR)
@@ -105,7 +114,7 @@ class BaseConfig:
     @staticmethod
     def init_app(app):
         pass
-
+    
 class DevelopmentConfig(BaseConfig):
     ''' 开发环境配置, 开启调试模式, 使用 sqlite '''
     DEBUG = True

+ 16 - 1
apps/extensions/__init__.py

@@ -1 +1,16 @@
-from .init_sqlalhemy import db
+from .init_sqlalhemy import db, init_databases
+from flask import Flask
+import paddlex as pdx
+
+def init_plugins(app:Flask):
+    init_databases(app)
+    init_paddle(app)
+
+def init_paddle(app):
+    with app.app_context():
+        use_gpu=True
+        if not pdx.device.is_compiled_with_cuda():
+            print("PaddlePaddle is not compiled with CUDA. CPU will be used.")
+            use_gpu = False
+        app.pdx_model = pdx.deploy.Predictor('./core/net/inference_model', use_gpu=use_gpu)
+

+ 15 - 0
apps/extensions/init_sqlalhemy.py

@@ -53,3 +53,18 @@ ma = Marshmallow()
 
 migrate = Migrate()
 
+def init_databases(app: Flask):
+    db.init_app(app)
+    # db.create_all(app=app)
+    ma.init_app(app)
+
+    if os.environ.get('WERKZEUG_RUN_MAIN') == 'true':
+        with app.app_context():
+            try:
+                db.engine.connect()
+                # 导入sql
+                # with app.open_resource('schema.sql', mode='r') as f:
+                #     db.cursor().executescript(f.read())
+                # db.commit()
+            except Exception as e:
+                exit(f"数据库连接失败: {e}")

+ 1 - 0
apps/views/__init__.py

@@ -4,6 +4,7 @@ from flask import Flask
 from .home import bp as home_bp
 
 def init_blueprints(app: Flask):
+    ''' init routes '''
     app.register_blueprint(v1_bp)
     app.register_blueprint(v2_bp)
     app.register_blueprint(home_bp)

+ 21 - 4
apps/views/api/v1.py

@@ -1,10 +1,19 @@
-from flask import Blueprint, request, jsonify, redirect, url_for, current_app, send_from_directory, make_response
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@Contact :   liuyuqi.gov@msn.cn
+@Time    :   2023/11/09 14:21:52
+@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
+@Desc    :   api v1
+'''
+
+from flask import Blueprint, request, jsonify, current_app, send_from_directory, make_response
 import datetime
 import os
 import shutil
+from medical_assist import predict
 
 bp = Blueprint('v1', __name__, url_prefix='/api/v1')
-# get flask app
 
 def allowed_file(filename):
     ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp'])
@@ -17,8 +26,16 @@ def home():
                       'msg': "api/v1"
                     }))
 
-@bp.route('/upload', methods=['GET', 'POST'])
+@bp.route('/upload', methods=['GET','POST'])
 def upload_file():
+    ''' upload file '''
+
+    if 'file' not in request.files:
+        return jsonify({
+            'status': -1,
+            'msg': 'no file part'
+                        })
+    
     file = request.files['file']
     print(datetime.datetime.now(), file.filename)
     if file and allowed_file(file.filename):
@@ -27,7 +44,7 @@ def upload_file():
         shutil.copy(src_path, './tmp/ct')
         image_path = os.path.join('./tmp/ct', file.filename)
         print(src_path, image_path)
-        pid, image_info = paddlex.main.c_main(image_path, current_app.model)
+        pid, image_info = predict(image_path, current_app.pdx_model)
         return jsonify({'status': 1,
                         'image_url': 'http://127.0.0.1:5003/tmp/ct/' + pid,
                         'draw_url': 'http://127.0.0.1:5003/tmp/draw/' + pid,

+ 9 - 0
apps/views/api/v2.py

@@ -1,3 +1,12 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@Contact :   liuyuqi.gov@msn.cn
+@Time    :   2023/11/09 14:22:03
+@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
+@Desc    :   api v2
+'''
+
 from flask import Blueprint, render_template, request, jsonify
 
 bp = Blueprint('v2', __name__, url_prefix='/api/v2')

+ 10 - 4
apps/views/home.py

@@ -1,7 +1,13 @@
-from flask import Blueprint, request, jsonify, redirect, url_for, current_app, send_from_directory, make_response
-import datetime
-import os
-import shutil
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@Contact :   liuyuqi.gov@msn.cn
+@Time    :   2023/11/09 14:22:14
+@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
+@Desc    :   home, redirect to index.html
+'''
+
+from flask import Blueprint, redirect, url_for
 
 bp = Blueprint('home', __name__)
 

+ 9 - 0
docs/model_train.ipynb

@@ -1,5 +1,14 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 训练模型\n",
+    "\n",
+    "基于 paddlex "
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

+ 1 - 1
frontend/package.json

@@ -50,4 +50,4 @@
     "last 2 versions",
     "not ie <= 8"
   ]
-}
+}

+ 3 - 0
medical_assist/__init__.py

@@ -0,0 +1,3 @@
+from .assist import predict
+
+

+ 3 - 5
paddlex/paddlex.py → medical_assist/assist.py

@@ -1,22 +1,20 @@
 import numpy as np
 import cv2
+from .get_feature import main
 import os
-from .get_image_info import main
 
 rate = 0.5
 
 def predict(path, model):
     ''' predict
-    Args: path - image path
+    Args: 
+    path - image path
     Returns: image path, image info
     '''
     global img_y
     file_name = os.path.split(path)[1].split('.')[0]
     
     x = path.replace('\\', '/')
-    file_name = file_name
-    print(x)
-    print(file_name)
     img_y = model.predict(x)['label_map']
     img_y = img_y * 255
     img_y = img_y.astype(np.int)

+ 0 - 0
paddlex/get_feature.py → medical_assist/get_feature.py


+ 0 - 0
paddlex/__init__.py


+ 1 - 0
requirements.txt

@@ -1,5 +1,6 @@
 # paddlepaddle
 flask==2.3.2
 python-dotenv==1.0.0
+opencv-python
 
 

+ 17 - 0
test/flask_test.py

@@ -0,0 +1,17 @@
+
+from flask import Flask 
+
+app = Flask(__name__, template_folder="GUI")
+
+def init_hook(app: Flask):
+    ''' init hook '''
+
+    @app.after_request
+    def after_request(response):
+        ''' resolve cross-domain '''
+        response.headers["loww"] = "nas3333333e"
+        return response
+init_hook(app)
+
+app.run(debug=True)
+