cv与一些classify事项记录
# preface: 一些 project 记录,慢慢学吧调参炼丹
# 需要实现的功能是接收树莓派传过来的植物图像,对其是否病害进行分类预测,用的训练集是 [这个](病虫害分类数据集 - 飞桨 AI Studio (baidu.com)),训练集的话需要写一个加载 json 数据的方法
def create_image_folders_by_class(json_file_path, image_folder_path, output_folder_path): | |
with open(json_file_path, 'r') as f: | |
labels_data = json.load(f) | |
# 分割训练集和验证集 | |
train_data, val_data = train_test_split(labels_data, test_size=0.2, random_state=42) | |
for name, data in [('train', train_data), ('val', val_data)]: | |
for item in data: | |
class_folder = os.path.join(output_folder_path, name, str(item['disease_class'])) | |
os.makedirs(class_folder, exist_ok=True) | |
shutil.copy(os.path.join(image_folder_path, item['image_id']), os.path.join(class_folder, item['image_id'])) |
# 最开始使用的是 ResNet50,后面感觉正确率太低了会不会跟模型有关,改成 densenet169,效果是差不太多…
# 还发现一个小问题,数据集里面都是带病的植物 class,好像没有健康的…
# 于是乎感觉反正准确率偏低了,干脆根据置信度分一类专门来判别没病的植物么得,重写一些训模型的部分,log 一下平均置信度和方差,然后假装正则大概放一个置信度来判别
# 大概大概这样子吧
def predict_image(image_path): | |
img = Image.open(image_path) | |
img = data_transform(img).unsqueeze(0) | |
img = img.to(device) | |
output = model(img) | |
probs = torch.nn.functional.softmax(output, dim=1) | |
max_prob, preds = torch.max(probs, 1) | |
# print(max_prob) | |
threshold = 0.7 | |
if max_prob < threshold: | |
return 'Low confidence' | |
else: | |
return preds |