A Guide to Data Processing with PyTorch on CentOS

2025/3/9 21:31:42

To process data with PyTorch on CentOS, first make sure that a suitable Python version and PyTorch are installed on the system. The detailed steps are as follows:

Install Python

  1. Update the system
sudo yum update -y
  2. Install Python 3 and pip
sudo yum install -y python3 python3-pip
  3. Verify the Python installation
python3 --version
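The same check can be run from inside Python, which is handy at the top of a setup script. This is a minimal sketch: the helper name check_python is hypothetical, and its default minimum of (3, 8) is an assumption chosen to match the python=3.8 environment used in this guide, not an official PyTorch requirement.

```python
import importlib.util
import sys


def check_python(minimum=(3, 8)):
    """Return (version_string, pip_available); raise if Python is older than `minimum`.

    The default minimum of (3, 8) is an assumption matching this guide's
    conda environment, not a requirement taken from PyTorch itself.
    """
    if sys.version_info < minimum:
        raise RuntimeError(
            f"Python {'.'.join(map(str, minimum))}+ required, "
            f"found {sys.version.split()[0]}"
        )
    # find_spec returns None when the module cannot be found on sys.path
    pip_available = importlib.util.find_spec("pip") is not None
    return sys.version.split()[0], pip_available


if __name__ == "__main__":
    version, has_pip = check_python()
    print(f"Python {version}, pip available: {has_pip}")
```

Raising an exception instead of printing a warning lets a setup script stop early when it is run with the wrong interpreter.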

Install PyTorch

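  1. Install Miniconda (recommended):
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
When the installer finishes, open a new shell (or run source ~/.bashrc) so that the conda command is available.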
  2. Create and activate a virtual environment
conda create -n torch_env python=3.8
conda activate torch_env
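To confirm that a script really runs inside the activated environment, you can inspect the variable that conda activate exports. A minimal sketch, assuming the environment name torch_env from the commands above; active_conda_env and in_expected_env are hypothetical helper names.

```python
import os
import sys


def active_conda_env(environ=None):
    """Return the name of the active conda environment, or None if none is active."""
    environ = os.environ if environ is None else environ
    # `conda activate <name>` exports CONDA_DEFAULT_ENV=<name>
    return environ.get("CONDA_DEFAULT_ENV") or None


def in_expected_env(expected="torch_env", environ=None):
    """True if the active conda environment matches `expected` (the default name is an assumption)."""
    return active_conda_env(environ) == expected


if __name__ == "__main__":
    print("Active conda env:", active_conda_env())
    # sys.prefix points at the environment the running interpreter belongs to
    print("Interpreter prefix:", sys.prefix)
    print("Using torch_env:", in_expected_env())
```

Passing the environment mapping as a parameter keeps the helpers easy to test without touching the real shell environment.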
  3. Install PyTorch
  • CPU version:
conda install pytorch torchvision torchaudio cpuonly -c pytorch
  • GPU version (requires an NVIDIA GPU and driver):
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

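Replace 11.3 with a CUDA version that your NVIDIA driver supports (the header printed by nvidia-smi shows the highest supported version). The conda package ships its own CUDA runtime, so a separately installed system-wide CUDA toolkit is not required. Note that cudatoolkit=11.3 corresponds to older PyTorch releases; for current releases, use the install command generated on the PyTorch website.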

  4. Verify the PyTorch installation
python -c "import torch; print(torch.__version__)"
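Printing only the version does not tell you whether the GPU build can actually see a GPU. A slightly fuller check, sketched below with a hypothetical helper torch_env_report, also reports the CUDA version PyTorch was built with and picks a device.

```python
import torch


def torch_env_report():
    """Collect basic facts about the installed PyTorch build."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch_version": torch.__version__,
        # CUDA version PyTorch was built with; None for CPU-only builds
        "built_with_cuda": torch.version.cuda,
        "cuda_available": cuda_ok,
        "gpu_count": torch.cuda.device_count() if cuda_ok else 0,
        "device": "cuda" if cuda_ok else "cpu",
    }


if __name__ == "__main__":
    report = torch_env_report()
    for key, value in report.items():
        print(f"{key}: {value}")
    # Smoke test: a small matrix multiply on the selected device
    x = torch.ones(2, 2, device=torch.device(report["device"]))
    print((x @ x).cpu())
```

If cuda_available is False on a machine with a GPU, the CPU-only package was most likely installed, or the driver is older than the CUDA version of the build.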

Data Processing

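  1. Load a built-in dataset (FashionMNIST is downloaded to ./data on first use)
import torch
from torchvision import datasets, transforms

# ToTensor converts images to float tensors in [0, 1];
# Normalize((0.5,), (0.5,)) then maps the single grayscale channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# train=True selects the 60,000-image training split, train=False the 10,000-image test split
train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)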
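The effect of this transform pipeline on a single pixel can be worked out by hand: ToTensor divides 8-bit values by 255, and Normalize computes (x - mean) / std per channel. A plain-Python sketch of that arithmetic for mean = std = 0.5; the helper names are illustrative, not torchvision functions.

```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0-255) to [0.0, 1.0], as ToTensor does for uint8 images."""
    return pixel / 255.0


def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization: (x - mean) / std."""
    return (x - mean) / std


if __name__ == "__main__":
    for pixel in (0, 51, 255):
        print(pixel, "->", round(normalize(to_unit_range(pixel)), 3))
```

Black (0) ends up at -1, white (255) at 1, and mid-gray values fall in between, which is why the normalized images are centered around zero.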
  2. Create a custom dataset
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # CSV with one row per image: column 0 = file name, column 1 = label
        # (pd.read_csv treats the first row as a header by default)
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples = number of rows in the annotations file
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # read_image returns a uint8 tensor of shape (C, H, W), so any transform
        # passed in must accept tensors (ToTensor expects a PIL image or ndarray)
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
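The same Dataset pattern (__len__ returns the sample count, __getitem__ returns one (input, label) pair) also works for non-image data. A minimal sketch assuming a hypothetical headerless CSV in which each row holds numeric features followed by an integer label; it uses the standard csv module instead of pandas, and the class name CsvFeatureDataset is illustrative.

```python
import csv
import os
import tempfile

import torch
from torch.utils.data import Dataset


class CsvFeatureDataset(Dataset):
    """Hypothetical dataset: each CSV row is `feature_1,...,feature_n,label` (no header)."""

    def __init__(self, csv_path, transform=None, target_transform=None):
        with open(csv_path, newline="") as f:
            rows = [row for row in csv.reader(f) if row]  # skip blank lines
        # Parse once up front so __getitem__ stays cheap
        self.features = [[float(v) for v in row[:-1]] for row in rows]
        self.labels = [int(row[-1]) for row in rows]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.tensor(self.features[idx], dtype=torch.float32)
        y = self.labels[idx]
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y


# Usage: write a tiny CSV to a temporary directory and read it back
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "features.csv")
with open(csv_path, "w", newline="") as f:
    csv.writer(f).writerows([[0.1, 0.2, 0], [0.3, 0.4, 1], [0.5, 0.6, 1]])

ds = CsvFeatureDataset(csv_path, transform=lambda t: t * 10)
x, y = ds[1]
print(len(ds), x, y)
```

Parsing everything in __init__ suits small files; for data that does not fit in memory, reading rows lazily inside __getitem__ is the usual alternative.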
  3. Iterate over and visualize the dataset
import matplotlib.pyplot as plt

labels_map = {0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot"}

figure, axes = plt.subplots(3, 3, figsize=(8, 8))
# Fill all 9 cells of the 3x3 grid with randomly chosen training samples
for i in range(9):
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    ax = axes[i // 3, i % 3]
    ax.imshow(img.squeeze(), cmap='gray')  # drop the channel dimension: (1, 28, 28) -> (28, 28)
    ax.set_title(labels_map[label])
    ax.axis("off")
plt.tight_layout()
plt.show()
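torch.randint can return the same index more than once, so the grid may show duplicate images. One alternative, sketched here under the assumption that you want 9 distinct samples, is to take the first rows*cols entries of torch.randperm; the helper name grid_sample_indices is hypothetical, and a seeded torch.Generator makes the selection reproducible.

```python
import torch


def grid_sample_indices(n_items, rows=3, cols=3, seed=0):
    """Pick rows*cols distinct dataset indices and pair each with its (row, col) grid cell."""
    if n_items < rows * cols:
        raise ValueError("dataset is smaller than the grid")
    gen = torch.Generator().manual_seed(seed)  # fixed seed -> same picks every run
    picks = torch.randperm(n_items, generator=gen)[: rows * cols].tolist()
    # Cell i of a row-major grid sits at row i // cols, column i % cols
    return [(idx, i // cols, i % cols) for i, idx in enumerate(picks)]


if __name__ == "__main__":
    # 60000 = size of the FashionMNIST training split
    for idx, r, c in grid_sample_indices(60000):
        print(f"sample {idx} -> axes[{r}, {c}]")
```

In the plotting loop, you would iterate over these tuples and load train_data[idx] into axes[r, c].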
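  4. Batch the data with DataLoader
from torch.utils.data import DataLoader

# Reshuffle the training data every epoch; keep the test order fixed for reproducible evaluation
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

# Fetch a single batch to inspect its shape
for images, labels in train_dataloader:
    print(f"Feature batch shape: {images.size()}")  # torch.Size([64, 1, 28, 28])
    print(f"Labels batch shape: {labels.size()}")   # torch.Size([64])
    break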
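How DataLoader splits a dataset into batches can be checked without downloading anything. A minimal sketch using random tensors shaped like FashionMNIST (the sample count of 100 is arbitrary): with batch_size=64 the last batch holds the 36 leftover samples unless drop_last=True discards it. num_workers is left at its default of 0 so the example runs in the main process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake grayscale 28x28 images with labels in 0..9
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)


def batch_sizes(loader):
    """Return the number of samples in each batch the loader yields."""
    return [x.size(0) for x, _ in loader]


keep_last = DataLoader(dataset, batch_size=64, shuffle=True)
drop_last = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

print(batch_sizes(keep_last))  # [64, 36]
print(batch_sizes(drop_last))  # [64]

x, y = next(iter(keep_last))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```

Dropping the last partial batch can matter for layers such as batch normalization, which behave poorly on very small batches; for evaluation you normally keep it so every test sample is counted.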

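The steps above cover the basic workflow for installing PyTorch on CentOS and processing data with it: configure the system environment correctly, install PyTorch with the command that matches your hardware (CPU or GPU), and use the example code as a starting point for loading, inspecting, and batching data.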
