"""Flask service that runs a ResNet-101 segmentation-style model on uploads.

Endpoints:
    POST /predict    -- multipart upload in the form field ``file``. The file
                        is stored under ``root_dir``, decoded as a grayscale
                        image and passed to ``ResNet2d.inference``; the reply
                        is ``'category,<mask>'``.
    GET  /getresult  -- ``?file=<name>`` downloads a file previously stored
                        under ``root_dir``.
"""
import os

# CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA runtime
# is initialised, so it must be set before torch / the project model module
# are imported (either may touch CUDA at import time).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Synchronous kernel launches: clearer CUDA error locations, at a speed cost.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from flask import Flask, request, send_file  # noqa: E402

from model import MutilResNet2dModel  # noqa: E402

use_cuda = torch.cuda.is_available()

app = Flask(__name__)

# Load the model once at start-up; every request reuses this instance.
ResNet2d = MutilResNet2dModel(image_height=512, image_width=512, image_channel=1, numclass=2,
                              batch_size=16, loss_name='MutilCrossEntropyLoss',
                              model_name='resnet101', accum_iter=1, use_cuda=use_cuda,
                              inference=True,
                              model_path=r'log/resnet101/ce/MutilResNet2d.pth')

# Directory that holds uploaded files; can be overridden via UPLOAD_DIR.
root_dir = os.environ.get('UPLOAD_DIR', r"D:/uploads")
os.makedirs(root_dir, exist_ok=True)


def _resolve_in_root(filename):
    """Return the absolute path of *filename* inside ``root_dir``, or None.

    None is returned when *filename* is empty, refers to ``root_dir`` itself,
    or would resolve outside ``root_dir`` (``../x``, absolute paths, another
    drive on Windows). This blocks path traversal through client-supplied
    file names.
    """
    if not filename:
        return None
    root = os.path.realpath(root_dir)
    candidate = os.path.realpath(os.path.join(root, filename))
    try:
        inside = os.path.commonpath([root, candidate]) == root
    except ValueError:  # e.g. paths on different drives on Windows
        return None
    if not inside or candidate == root:
        return None
    return candidate


# Service endpoint: run inference on an uploaded image.
@app.route('/predict', methods=['POST'])
def predict():
    """Save the uploaded image, run the model on it and return the mask as text."""
    file = request.files.get('file')  # the uploaded file, if any
    if not file:
        return 'No file uploaded'

    save_path = _resolve_in_root(file.filename)
    if save_path is None:
        return 'Invalid file name'
    file.save(save_path)  # keep a local copy (also served by /getresult)

    # np.fromfile + imdecode instead of cv2.imread: cv2.imread cannot open
    # non-ASCII paths on Windows and silently returns None in that case.
    data = np.fromfile(save_path, dtype=np.uint8)
    image = cv2.imdecode(data, cv2.IMREAD_GRAYSCALE) if data.size else None
    if image is None:
        # Not a decodable image; don't feed None into the model.
        return 'Uploaded file is not a readable image'

    mask, _mask_prob = ResNet2d.inference(image)
    return f'category,{str(mask)}'


# Service endpoint: download a previously stored file.
@app.route('/getresult', methods=['GET'])
def getresult():
    """Send the file named by the ``file`` query parameter as an attachment."""
    filename = request.args.get('file')  # file name from the query string
    if not filename:
        return "Missing parameter: file"  # no file name supplied

    filepath = _resolve_in_root(filename)
    if filepath is None or not os.path.isfile(filepath):
        return "The file does not exist"  # missing, or outside root_dir
    try:
        # No attachment_filename/download_name kwarg: the former was renamed
        # in Flask 2.0 and removed in 2.2. The default download name is the
        # file's basename, i.e. the requested name.
        return send_file(filepath, as_attachment=True)
    except FileNotFoundError:  # deleted between the check and the send
        return "The file does not exist"


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)