To use PyTorch for data processing on CentOS, first make sure a suitable version of Python and PyTorch is installed on the system. Here is a step-by-step guide:
Update the system, install Python 3 and pip, and check the Python version:

```bash
sudo yum update -y
sudo yum install -y python3 python3-pip
python3 --version
```
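Next, install Miniconda, which is used below to create an isolated environment for PyTorch:

```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
# If you let the installer initialize conda, open a new shell (or run: source ~/.bashrc) so the conda command is found
```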
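Create and activate a dedicated environment:

```bash
conda create -n torch_env python=3.8
conda activate torch_env
```

Recent PyTorch releases no longer support Python 3.8; if you need a current release, create the environment with a newer Python such as 3.10.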
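Install PyTorch. On a CPU-only machine:

```bash
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```

On a machine with an NVIDIA GPU:

```bash
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```

Replace `11.3` with a CUDA version that PyTorch provides builds for and that your NVIDIA driver supports (`nvidia-smi` shows the highest CUDA version the driver supports). The install selector on the official PyTorch website generates the exact command for your combination. Note that recent PyTorch releases are no longer published to the `pytorch` conda channel and are installed with `pip` instead.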
Verify the installation:

```bash
python -c "import torch; print(torch.__version__)"
```
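Beyond printing the version, it is often useful to check whether PyTorch can actually see a GPU. The following is a minimal sketch (the function name `describe_torch_env` is hypothetical) that also runs when PyTorch is not installed yet; from PyTorch it only uses `torch.__version__`, `torch.cuda.is_available()`, `torch.version.cuda`, and `torch.cuda.get_device_name()`.

```python
import importlib.util
import platform


def describe_torch_env():
    """Collect basic facts about the Python/PyTorch environment (hypothetical helper)."""
    info = {
        "python": platform.python_version(),
        "torch": None,            # stays None if PyTorch is not installed
        "cuda_available": False,
    }
    # find_spec returns None when the package cannot be found
    if importlib.util.find_spec("torch") is None:
        return info

    import torch

    info["torch"] = torch.__version__
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        # CUDA version PyTorch was built against, and the name of the first GPU
        info["cuda_build"] = torch.version.cuda
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in describe_torch_env().items():
        print(f"{key}: {value}")
```

With the `cpuonly` install, `cuda_available` should be `False`; with the CUDA install on a working driver it should be `True`.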
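Load a built-in dataset. torchvision provides datasets such as FashionMNIST; the code below downloads it to `./data` and converts each image to a normalized tensor:

```python
import torch
from torchvision import datasets, transforms

# Convert PIL images to tensors in [0, 1], then normalize with mean 0.5 and std 0.5 to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Training set (60,000 images) and test set (10,000 images), downloaded on first use
train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
```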
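For 8-bit images such as FashionMNIST, `ToTensor()` scales pixel values from [0, 255] to [0.0, 1.0], and `Normalize((0.5,), (0.5,))` then computes `(x - mean) / std` per channel, so the result lies in [-1.0, 1.0]. The pure-Python sketch below reproduces that arithmetic for a single grayscale pixel to make the mapping concrete; it is an illustration, not torchvision's implementation.

```python
def to_tensor_value(pixel: int) -> float:
    """Mimic ToTensor() for one uint8 pixel: scale [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0


def normalize_value(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic Normalize((mean,), (std,)) for one channel value."""
    return (x - mean) / std


for pixel in (0, 128, 255):
    print(pixel, round(normalize_value(to_tensor_value(pixel)), 4))
```

Black pixels (0) map to -1.0 and white pixels (255) map to 1.0, which centers the inputs around zero.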
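To use your own images, subclass `Dataset` and implement `__init__`, `__len__`, and `__getitem__`. Here `annotations_file` is a CSV file whose first column is the image file name (relative to `img_dir`) and whose second column is the label:

```python
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # One row per image: file name, label
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Load the image at position idx as a uint8 tensor of shape (C, H, W)
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```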
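`CustomImageDataset` needs an annotations CSV to already exist. The sketch below builds one from a hypothetical folder layout `img_dir/<class_name>/<image file>` using only the standard library; the layout, the extension filter, and the function name `build_annotations` are assumptions, not part of PyTorch. It writes a header row because `pd.read_csv` treats the first line as column names by default, and it stores paths relative to `img_dir` so that `os.path.join(self.img_dir, ...)` resolves them.

```python
import csv
import os


def build_annotations(img_dir, annotations_file, extensions=(".png", ".jpg", ".jpeg")):
    """Write a CSV of (relative image path, integer label) rows.

    Assumes a hypothetical layout img_dir/<class_name>/<image file>.
    Returns the class_name -> label mapping that was used.
    """
    # Sort class folders so the label assignment is deterministic
    class_names = sorted(
        d for d in os.listdir(img_dir) if os.path.isdir(os.path.join(img_dir, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    with open(annotations_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "label"])  # header row consumed by pd.read_csv
        for name in class_names:
            for file_name in sorted(os.listdir(os.path.join(img_dir, name))):
                if file_name.lower().endswith(extensions):
                    # Path relative to img_dir, e.g. "cat/a.png"
                    writer.writerow([os.path.join(name, file_name), class_to_idx[name]])
    return class_to_idx
```

For example, `build_annotations("./my_images", "labels.csv")` followed by `CustomImageDataset("labels.csv", "./my_images")` would give a dataset whose labels are the integer indices of the sorted class folder names.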
Visualize a few training samples:

```python
import matplotlib.pyplot as plt
import torch

labels_map = {0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
              5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot"}

figure, axes = plt.subplots(3, 3, figsize=(8, 8))
for i in range(9):  # fill all 9 cells of the 3x3 grid
    # Pick a random training sample; size must be a tuple
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    ax = axes[i // 3, i % 3]
    ax.imshow(img.squeeze(), cmap='gray')  # (1, 28, 28) -> (28, 28)
    ax.set_title(labels_map[label])
    ax.axis("off")
plt.tight_layout()
plt.show()
```
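`torch.randint` samples with replacement, so the same image can appear in more than one panel. If you want nine distinct images, one option is to draw the indices without replacement with `random.sample` from the standard library and derive each panel's (row, col) position with `divmod`. This is a swapped-in variant, and the helper name `pick_grid_samples` is hypothetical.

```python
import random


def pick_grid_samples(dataset_len, rows=3, cols=3, seed=None):
    """Return [((row, col), dataset_index), ...] with distinct indices.

    Raises ValueError if the dataset has fewer items than grid cells.
    """
    rng = random.Random(seed)  # a fixed seed makes the selection reproducible
    # sample() draws without replacement, so no image is shown twice
    indices = rng.sample(range(dataset_len), rows * cols)
    # divmod(k, cols) walks the grid row by row: (0, 0), (0, 1), ..., (rows-1, cols-1)
    return [(divmod(k, cols), idx) for k, idx in enumerate(indices)]
```

In the plotting loop you would then iterate `for (r, c), idx in pick_grid_samples(len(train_data)):` and draw `train_data[idx]` into `axes[r, c]`.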
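Wrap the datasets in a `DataLoader` to iterate over mini-batches. The training set is shuffled so each epoch sees a different order; the test set does not need shuffling, since evaluation results do not depend on order. `num_workers=2` loads batches in two background worker processes:

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

# Fetch one batch to check the shapes
for images, labels in train_dataloader:
    print(f"Feature batch shape: {images.size()}")  # torch.Size([64, 1, 28, 28])
    print(f"Labels batch shape: {labels.size()}")   # torch.Size([64])
    break
```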
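Conceptually, with `shuffle=True` the `DataLoader` draws a fresh random permutation of the dataset indices each epoch and groups it into chunks of `batch_size`; because `drop_last` defaults to `False`, the final batch can be smaller. The sketch below models only that index logic in plain Python (it ignores worker processes, collation, and custom samplers), which is handy for estimating batches per epoch: the 60,000 FashionMNIST training images at `batch_size=64` give 938 batches, the last one holding 32 images.

```python
import math
import random


def batch_indices(num_samples, batch_size, shuffle=True, drop_last=False, seed=None):
    """Yield lists of dataset indices, one list per mini-batch (simplified model)."""
    order = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(order)  # new permutation -> new batch composition
    for start in range(0, num_samples, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # mimic drop_last=True: discard the incomplete final batch
        yield batch


def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch, matching len(DataLoader) for this simple case."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


if __name__ == "__main__":
    batches = list(batch_indices(60000, 64, seed=0))
    print(len(batches), len(batches[-1]), num_batches(60000, 64))  # 938 32 938
```

Passing `drop_last=True` to `DataLoader` (and to this sketch) would instead give 937 full batches and skip the last 32 images each epoch.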
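The steps above cover the basic workflow for installing PyTorch on CentOS and processing data with it: configure the system environment correctly, install PyTorch with the command that matches your hardware, and use the example code to load, inspect, and batch data.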