Contact Information
  • Telephone: 13693115325
  • WeChat: liuyiliang100
  • Email: quantumliu@pku.edu.cn
Join the developer WeChat group

Data Processing

During model training, we may run into overfitting. One remedy is to apply data augmentation to the training data: processing the data in specific ways, such as cropping, flipping, or adjusting image brightness, increases the diversity of the samples and thereby improves the model's ability to generalize.

1. Introduction to tensorlayerx.vision.transforms

The TensorLayerX framework provides dozens of built-in image data processing methods in tensorlayerx.vision.transforms. You can list them with the following code:

import tensorlayerx

print('Image data processing methods: ', tensorlayerx.vision.transforms.__all__)

Image data processing methods: ['Crop', 'CentralCrop', 'HsvToRgb', 'AdjustBrightness', 'AdjustContrast', 'AdjustHue', 'AdjustSaturation', 'FlipHorizontal', 'FlipVertical', 'RgbToGray', 'PadToBoundingbox', 'Pad', 'Normalize', 'StandardizePerImage', 'RandomBrightness', 'RandomContrast', 'RandomHue', 'RandomSaturation', 'RandomCrop', 'Resize', 'RgbToHsv', 'Transpose', 'Rotation', 'RandomRotation', 'RandomShift', 'RandomShear', 'RandomZoom', 'RandomFlipVertical', 'RandomFlipHorizontal', 'HWC2CHW', 'CHW2HWC', 'ToTensor', 'Compose', 'RandomResizedCrop', 'RandomAffine', 'ColorJitter', 'Rotation']

These include common operations such as random cropping, rotation, and brightness and contrast adjustment. Each method is described in the API documentation.

The built-in data preprocessing methods in TensorLayerX can be called individually, or several of them can be combined into one pipeline. Usage is as follows:

  • Use individually

from tensorlayerx.vision.transforms import Resize

transform = Resize(size=(100, 100), interpolation='bilinear')
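To make the interpolation='bilinear' option concrete, here is a minimal NumPy sketch of bilinear resizing. It is an illustration, not TensorLayerX's Resize implementation: the function name bilinear_resize is hypothetical, and the half-pixel-center coordinate mapping is an assumption (it is the convention OpenCV uses for INTER_LINEAR). Each output pixel is a weighted average of its four nearest input pixels, with weights given by the fractional distances.

```python
import numpy as np


def bilinear_resize(image, size):
    """Resize an (H, W) or (H, W, C) array to size=(out_h, out_w) bilinearly.

    Hypothetical helper for illustration; assumes half-pixel-center mapping.
    """
    in_h, in_w = image.shape[:2]
    out_h, out_w = size
    img = image.astype(np.float64)

    # Map each output pixel center back to a fractional input coordinate
    ys = (np.arange(out_h) + 0.5) * in_h / out_h - 0.5
    xs = (np.arange(out_w) + 0.5) * in_w / out_w - 0.5
    ys = np.clip(ys, 0, in_h - 1)
    xs = np.clip(xs, 0, in_w - 1)

    # Integer neighbours on either side of each coordinate
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)

    # Fractional distances act as interpolation weights
    wy = (ys - y0)[:, None]
    wx = (xs - x0)[None, :]
    if img.ndim == 3:  # add a channel axis so the weights broadcast
        wy = wy[..., None]
        wx = wx[..., None]

    # Interpolate along x on the two neighbouring rows, then along y
    top = img[y0][:, x0] * (1 - wx) + img[y0][:, x1] * wx
    bottom = img[y1][:, x0] * (1 - wx) + img[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy


small = np.random.default_rng(0).random((60, 80, 3))
resized = bilinear_resize(small, (100, 100))
print(resized.shape)  # (100, 100, 3)
```

Resizing to the input's own size returns the image unchanged, and a constant image stays constant at any size, which are quick sanity checks for any bilinear implementation.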

  • Combine multiple methods

In this case, first define each data processing method, then combine them with Compose, which applies them in list order.

from tensorlayerx.vision.transforms import (
    Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop
)

transforms = Compose(
    [
        RandomCrop(size=[24, 24]),
        RandomFlipHorizontal(),
        RandomBrightness(brightness_factor=(0.5, 1.5)),
        RandomContrast(contrast_factor=(0.5, 1.5)),
        StandardizePerImage()
    ]
)
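The idea behind Compose can be sketched in plain Python plus NumPy: it holds a list of callables and feeds the output of each one into the next. The class SimpleCompose and the helpers random_crop and standardize_per_image below are hypothetical stand-ins, not TensorLayerX code; the standardization formula (zero mean, standard deviation floored at 1/sqrt(N)) is modelled on TensorFlow's tf.image.per_image_standardization and is an assumption about what StandardizePerImage does. The flip, brightness and contrast steps are left out for brevity.

```python
import numpy as np


class SimpleCompose:
    """Hypothetical stand-in for Compose: applies each transform in list order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, image):
        for t in self.transforms:
            image = t(image)  # the output of one step is the input of the next
        return image


def random_crop(image, size, rng):
    """Cut a randomly placed size=(h, w) window out of an (H, W, ...) array."""
    h, w = size
    top = int(rng.integers(0, image.shape[0] - h + 1))
    left = int(rng.integers(0, image.shape[1] - w + 1))
    return image[top:top + h, left:left + w]


def standardize_per_image(image):
    """Shift to zero mean and scale to unit std (std floored at 1/sqrt(N))."""
    image = image.astype(np.float64)
    std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / std


rng = np.random.default_rng(0)
pipeline = SimpleCompose([
    lambda img: random_crop(img, (24, 24), rng),
    standardize_per_image,
])

image = rng.integers(0, 256, size=(32, 32, 3))
out = pipeline(image)
print(out.shape)  # (24, 24, 3)
```

Because the pipeline is just a callable, it can be handed to a dataset as a single transform argument, which is how the next section uses Compose.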

2. Apply data preprocessing operations in the dataset

Once a data processing method is defined, it can be applied directly inside a Dataset. The following shows how to use data preprocessing in a custom dataset.

For a custom dataset, pass the data processing method to the dataset's __init__ function and store it as an attribute of the class, then apply it to each image in __getitem__, as in the following code:

import tensorlayerx as tlx
from tensorlayerx.dataflow import Dataset
from tensorlayerx.vision.transforms import Compose, Normalize

# TensorLayerX automatically downloads and loads the MNIST dataset
print('download training data and load training data')
X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
# Rescale pixel values back to roughly the 0-255 range expected by the Normalize below
X_train = X_train * 255
print('load finished')


class MNISTDataset(Dataset):
    """
    Step 1: Inherit the tensorlayerx.dataflow.Dataset class
    """

    def __init__(self, data, label, transform=None):
        """
        Step 2: Implement the __init__ function to store the samples, the labels and the transform
        """
        self.data = data
        self.label = label
        self.transform = transform

    def __getitem__(self, index):
        """
        Step 3: Implement the __getitem__ function to define how to fetch the data at the specified index and return a single (sample, label) pair
        """
        data = self.data[index].astype('float32')
        # Apply the preprocessing pipeline to the image, if one was given
        if self.transform is not None:
            data = self.transform(data)
        label = self.label[index].astype('int64')
        return data, label

    def __len__(self):
        """
        Step 4: Implement the __len__ function to return the total number of samples in the dataset
        """
        return len(self.data)


# Map pixel values from [0, 255] to [-1, 1]
transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='HWC')])
train_dataset = MNISTDataset(data=X_train, label=y_train, transform=transform)
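Assuming Normalize computes the usual per-channel (x - mean) / std, the choice mean=127.5 and std=127.5 sends a pixel value of 0 to -1, 127.5 to 0 and 255 to 1. The hypothetical NumPy helper below reproduces that arithmetic on an HWC array; it is a sketch for checking the numbers, not TensorLayerX's implementation.

```python
import numpy as np


def normalize(image, mean, std):
    """Per-channel (x - mean) / std on an HWC array (hypothetical helper)."""
    mean = np.asarray(mean, dtype=np.float32)  # one value per channel
    std = np.asarray(std, dtype=np.float32)
    # Broadcasting applies mean/std along the last (channel) axis
    return (image.astype(np.float32) - mean) / std


pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32).reshape(1, 3, 1)  # 1x3 image, 1 channel
print(normalize(pixels, [127.5], [127.5]).ravel())  # [-1.  0.  1.]
```

With several channels, mean and std take one entry per channel, e.g. mean=[10, 20, 30] and std=[1, 2, 5] for an RGB image.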

 

3. Introduction to several data preprocessing methods

The effect of TensorLayerX's built-in data processing methods is easiest to see by plotting an image before and after the transform. The following examples compare several methods.

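First, download the example image into an images/ directory, which is where the code below reads it from:

# Download the example image to images/flower_demo.png
wget -P images https://paddle-imagenet-models-name.bj.bcebos.com/data/demo_images/flower_demo.png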

CentralCrop

Crops the central region of the input image to the given size, so the center of the image stays fixed.

import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import CentralCrop

# Keep a 224x224 region around the image center
transform = CentralCrop(size=(224, 224))

# OpenCV loads images in BGR channel order
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)

plt.subplot(1, 2, 1)
plt.title('original image')
# Reverse the channel axis (BGR -> RGB) for display
plt.imshow(image[:, :, ::-1])
plt.subplot(1, 2, 2)
plt.title('CentralCrop image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()

[Figure: original image (left) and CentralCrop result (right)]
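Center cropping reduces to array slicing: the top-left offset is half the difference between the input and output sizes. The NumPy sketch below is illustrative only; the name central_crop is hypothetical, and both the floor rounding for odd differences and the error on oversized crops are assumptions that other implementations may handle differently.

```python
import numpy as np


def central_crop(image, size):
    """Crop an (H, W, ...) array to size=(h, w) around its center (hypothetical helper)."""
    in_h, in_w = image.shape[:2]
    h, w = size
    if h > in_h or w > in_w:
        raise ValueError('crop size is larger than the image')
    # Offsets that split the removed border evenly (floor for odd differences)
    top = (in_h - h) // 2
    left = (in_w - w) // 2
    return image[top:top + h, left:left + w]


img = np.arange(300 * 400 * 3).reshape(300, 400, 3)
crop = central_crop(img, (224, 224))
print(crop.shape)  # (224, 224, 3)
```

For a 300x400 input, the crop starts at row (300 - 224) // 2 = 38 and column (400 - 224) // 2 = 88.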

RandomFlipHorizontal

Flips the image horizontally with a given probability (0.5 in this example).

import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import RandomFlipHorizontal

# Flip with probability 0.5, so a given run may or may not show a mirrored image
transform = RandomFlipHorizontal(0.5)

# OpenCV loads images in BGR channel order
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)

plt.subplot(1, 2, 1)
plt.title('original image')
# Reverse the channel axis (BGR -> RGB) for display
plt.imshow(image[:, :, ::-1])
plt.subplot(1, 2, 2)
plt.title('RandomFlipHorizontal image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()

[Figure: original image (left) and RandomFlipHorizontal result (right)]
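A random horizontal flip draws a uniform number in [0, 1) and mirrors the width axis when that number falls below the flip probability. The NumPy sketch below shows the mechanism; random_flip_horizontal is a hypothetical name, and passing an explicit random generator is a design choice made here so results are reproducible, not something the TensorLayerX API is known to require.

```python
import numpy as np


def random_flip_horizontal(image, prob=0.5, rng=None):
    """Mirror an (H, W, ...) array left-right with probability `prob` (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < prob:
        return image[:, ::-1]  # reverse the width axis
    return image


img = np.arange(12).reshape(2, 2, 3)
rng = np.random.default_rng(0)
flips = sum(
    not np.array_equal(random_flip_horizontal(img, 0.5, rng), img)
    for _ in range(1000)
)
print(flips)  # roughly 500 of 1000 calls flip the image
```

With prob=1 every call flips, and with prob=0 none do, which makes the behaviour easy to test.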

ColorJitter

Adjust the brightness, contrast, saturation and hue of the image randomly.

import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import ColorJitter

# Each adjustment factor is drawn at random from its range on every call
transform = ColorJitter(brightness=(1, 5), contrast=(1, 5), saturation=(1, 5), hue=(-0.2, 0.2))

# OpenCV loads images in BGR channel order
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)

plt.subplot(1, 2, 1)
plt.title('original image')
# Reverse the channel axis (BGR -> RGB) for display
plt.imshow(image[:, :, ::-1])
plt.subplot(1, 2, 2)
plt.title('ColorJitter image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()

[Figure: original image (left) and ColorJitter result (right)]
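The brightness and contrast parts of a color jitter can be sketched in NumPy with one common formulation: brightness multiplies every pixel by a factor, and contrast blends each pixel with the image's mean gray level. This is not TensorLayerX's ColorJitter code; the helper names are hypothetical, the input is assumed to be RGB (so an image from cv2.imread would need its channels reversed first), and saturation and hue, which need an HSV conversion, are omitted.

```python
import numpy as np


def adjust_brightness(image, factor):
    """Scale all pixels by `factor`; 1 leaves the image unchanged (hypothetical helper)."""
    return np.clip(image.astype(np.float64) * factor, 0, 255)


def adjust_contrast(image, factor):
    """Blend pixels with the mean gray level; 1 leaves the image unchanged (hypothetical helper)."""
    img = image.astype(np.float64)
    gray = img @ np.array([0.299, 0.587, 0.114])  # RGB -> luma, shape (H, W)
    mean = gray.mean()
    return np.clip((img - mean) * factor + mean, 0, 255)


def random_color_jitter(image, brightness, contrast, rng):
    """Draw one factor per adjustment uniformly from its (low, high) range."""
    b = rng.uniform(*brightness)
    c = rng.uniform(*contrast)
    return adjust_contrast(adjust_brightness(image, b), c)


rgb = np.random.default_rng(0).integers(0, 256, size=(8, 8, 3))
jittered = random_color_jitter(rgb, (1, 5), (1, 5), np.random.default_rng(1))
print(jittered.shape)  # (8, 8, 3)
```

Factors above 1 brighten or stretch the contrast and values below 1 do the opposite; clipping keeps results in the 0-255 range, which is why large ranges such as (1, 5) quickly saturate many pixels.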

For descriptions of the other data processing methods, refer to the tensorlayerx.vision.transforms API documentation.