食物照片识别卡路里(深度学习)
以下是一个基于 Python 的食物卡路里计算程序,使用深度学习模型识别食物种类,并通过食物数据库查询卡路里信息。程序使用 PyQt5 构建用户界面,支持加载食物图片、识别食物种类并计算卡路里。
主要功能说明:
1. 用户界面
- 主窗口:包含加载图片、计算卡路里、清除按钮,以及图片显示区域和结果展示区域。
- 图片显示:支持显示原始图片。
- 结果显示:显示识别到的食物种类和卡路里。
---
2. 核心功能
- 加载图片:用户可以通过点击“加载图片”按钮选择本地图片文件。
- 食物识别:使用 ResNet 模型识别图片中的食物种类。
- 卡路里计算:根据识别到的食物种类查询卡路里数据库,显示卡路里信息。
- 清除显示:点击“清除”按钮,清空图片和结果显示区域。
---
3. 技术细节
- ResNet 模型:使用 PyTorch 提供的预训练 ResNet 模型,并修改最后一层全连接层以输出 101 个类别(对应 Food-101 数据集)。
- 图片处理:使用 OpenCV 进行图片的加载和格式转换,并通过数据增强(如调整大小、随机翻转、归一化)提高模型鲁棒性。
- 食物数据库:包含常见食物的卡路里信息。
import sys
import cv2
import torch
import torchvision
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QPushButton, QLabel, QFileDialog, QLineEdit)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt
from torchvision import transforms
class FoodCalorieCalculator(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("食物卡路里计算器")
self.setGeometry(100, 100, 800, 600)
# 初始化UI
self.init_ui()
# 加载预训练的 ResNet 模型
self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
self.model.fc = torch.nn.Linear(self.model.fc.in_features, 101) # 修改输出类别数为 101
self.model.eval()
# Food-101 类别名称
self.food_classes = [
'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder',
'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',
'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots',
'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries',
'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt',
'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon',
'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros',
'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich',
'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette',
'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta',
'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',
'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops',
'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki',
'tiramisu', 'tuna_tartare', 'waffles'
]
# 食物卡路里数据库(示例)
self.calorie_data = {
'apple_pie': 237, 'hamburger': 295, 'pizza': 266, 'steak': 271, 'sushi': 200,
'fried_rice': 228, 'ramen': 436, 'pancakes': 227, 'waffles': 291, 'ice_cream': 207
}
def init_ui(self):
"""初始化用户界面"""
# 创建主窗口部件和布局
central_widget = QWidget()
self.setCentralWidget(central_widget)
layout = QVBoxLayout(central_widget)
# 创建顶部按钮区域
button_layout = QHBoxLayout()
# 添加按钮
self.btn_load = QPushButton("加载图片", self)
self.btn_load.clicked.connect(self.load_image)
button_layout.addWidget(self.btn_load)
self.btn_calculate = QPushButton("计算卡路里", self)
self.btn_calculate.clicked.connect(self.calculate_calories)
button_layout.addWidget(self.btn_calculate)
self.btn_clear = QPushButton("清除", self)
self.btn_clear.clicked.connect(self.clear_display)
button_layout.addWidget(self.btn_clear)
layout.addLayout(button_layout)
# 创建显示区域
display_layout = QHBoxLayout()
# 原始图片显示
self.image_label = QLabel()
self.image_label.setMinimumSize(400, 400)
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setStyleSheet("border: 2px solid black;")
display_layout.addWidget(self.image_label)
# 处理后的图片显示
self.processed_label = QLabel()
self.processed_label.setMinimumSize(400, 400)
self.processed_label.setAlignment(Qt.AlignCenter)
self.processed_label.setStyleSheet("border: 2px solid black;")
display_layout.addWidget(self.processed_label)
layout.addLayout(display_layout)
# 结果显示
self.result_label = QLabel("识别结果将在这里显示")
self.result_label.setAlignment(Qt.AlignCenter)
self.result_label.setStyleSheet("""
QLabel {
font-size: 24px;
margin: 20px;
padding: 10px;
background-color: #f0f0f0;
border-radius: 5px;
}
""")
layout.addWidget(self.result_label)
# 初始化变量
self.current_image = None
self.processed_image = None
def load_image(self):
"""加载图片"""
file_name, _ = QFileDialog.getOpenFileName(
self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg *.bmp)"
)
if file_name:
# 读取图片
self.current_image = cv2.imread(file_name)
if self.current_image is None:
self.result_label.setText("无法加载图片!")
return
# 显示原始图片
self.display_image(self.current_image, self.image_label)
def display_image(self, image, label):
"""显示图片到指定的标签"""
height, width = image.shape[:2]
bytes_per_line = 3 * width
q_image = QImage(image.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped()
pixmap = QPixmap.fromImage(q_image)
scaled_pixmap = pixmap.scaled(label.size(), Qt.KeepAspectRatio)
label.setPixmap(scaled_pixmap)
def calculate_calories(self):
"""计算食物卡路里"""
if self.current_image is None:
self.result_label.setText("请先加载图片!")
return
# 将OpenCV的BGR图片转换为RGB
image_rgb = cv2.cvtColor(self.current_image, cv2.COLOR_BGR2RGB)
# 数据增强
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 将NumPy数组转换为PyTorch Tensor
image_tensor = transform(image_rgb)
# 添加batch维度
image_tensor = image_tensor.unsqueeze(0)
# 使用模型进行食物识别
with torch.no_grad():
outputs = self.model(image_tensor)
_, predicted = torch.max(outputs, 1)
predicted_id = predicted.item()
# 检查类别 ID 是否在有效范围内
if 0 <= predicted_id < len(self.food_classes):
food_class = self.food_classes[predicted_id]
else:
self.result_label.setText("无法识别食物种类!")
return
# 查询卡路里
if food_class in self.calorie_data:
calories = self.calorie_data[food_class]
self.result_label.setText(f"识别结果: {food_class}\n卡路里: {calories} kcal")
else:
self.result_label.setText(f"识别结果: {food_class}\n未找到卡路里数据")
def clear_display(self):
"""清除显示"""
self.image_label.clear()
self.processed_label.clear()
self.result_label.setText("识别结果将在这里显示")
self.current_image = None
self.processed_image = None
def main():
app = QApplication(sys.argv)
window = FoodCalorieCalculator()
window.show()
sys.exit(app.exec_())
if __name__ == "__main__":
main()