Python实现自定义数据加载器:处理超大数据集与内存映射文件(mmap)

Python实现自定义数据加载器:处理超大数据集与内存映射文件(mmap)

大家好,今天我们来探讨一个在数据科学和机器学习领域非常重要的课题:如何有效地处理超大数据集。当数据集的大小超过了我们机器的物理内存容量时,传统的加载方式就显得力不从心。我们需要寻找更高效的方法,而内存映射文件(mmap)就是一种非常强大的工具。本文将深入讲解如何利用Python实现自定义数据加载器,并结合mmap技术来处理这类超大数据集。

1. 超大数据集带来的挑战

在深入代码之前,我们先来明确一下超大数据集带来的具体挑战:

  • 内存限制: 最直接的问题是内存容量不足。一次性将整个数据集加载到内存中是不可能的。
  • IO瓶颈: 频繁地从磁盘读取数据会成为性能瓶颈,因为磁盘IO的速度远低于内存访问速度。
  • 数据预处理: 对超大数据集进行预处理,例如清洗、转换和特征工程,同样需要高效的策略。

2. 内存映射文件(mmap)的概念

内存映射文件 (mmap) 是一种将文件内容映射到进程虚拟地址空间的技术。它允许程序像访问内存一样访问文件中的数据,而无需显式地进行读取或写入操作。操作系统负责在需要时将文件的一部分加载到内存中,并在不再需要时将其写回磁盘。

mmap的主要优势在于:

  • 节省内存: 数据并非一次性加载到内存中,而是按需加载,从而显著减少了内存占用。
  • 提高IO效率: 操作系统会缓存常用的数据块,减少了磁盘IO的次数。
  • 方便的数据共享: 多个进程可以同时映射同一个文件,实现高效的数据共享。

3. Python中的mmap模块

Python的mmap模块提供了对内存映射文件的支持。我们可以使用它来创建、读取和修改映射到内存的文件内容。

创建mmap对象:

import mmap

# 创建一个大小为1024字节的内存映射文件
with open("large_file.data", "wb") as f:
    f.write(b'' * 1024)  # 初始化文件内容,这里用空字节填充

with open("large_file.data", "r+b") as f:
    mm = mmap.mmap(f.fileno(), length=0) #length=0 映射整个文件

    #现在mm 就是一个 mmap 对象,可以像访问 bytearray 一样访问文件内容
    mm[0:5] = b"Hello"
    print(mm[0:5]) # 输出 b'Hello'

    mm.close()

代码解释:

  1. 首先,我们创建一个名为large_file.data的文件,并用空字节初始化其内容。 这是因为 mmap 需要一个已存在的文件。
  2. 然后,我们以读写二进制模式 (r+b) 打开该文件。
  3. 使用 mmap.mmap() 函数创建一个 mmap 对象。 f.fileno() 获取文件描述符,length=0表示映射整个文件。如果文件非常大,可以映射一部分,提高效率。
  4. 通过mm[0:5] = b"Hello",我们将文件的前5个字节修改为 "Hello"。
  5. 最后,我们关闭 mmap 对象。

常用mmap对象的方法:

方法 描述
close() 关闭 mmap 对象。
find(sub) 在映射的内存中查找子字符串 sub 的位置。
flush([offset, size]) 将对映射内存的更改刷新到磁盘。 offsetsize 指定要刷新的区域,如果省略,则刷新整个映射。
move(dst, src, count) count 个字节从 src 移动到 dst
read(size) 从当前位置读取 size 个字节。
readline() 从当前位置读取一行。
resize(newsize) 调整 mmap 对象的大小。 这会更改底层文件的大小。
seek(offset[, whence]) 移动文件指针到指定位置。 whence 可以是 os.SEEK_SET (从文件开头), os.SEEK_CUR (从当前位置), 或 os.SEEK_END (从文件末尾)。
size() 返回映射区域的大小。
tell() 返回当前文件指针的位置。
write(bytes) 从当前位置写入字节。
write_byte(byte) 从当前位置写入单个字节。

4. 自定义数据加载器:结合mmap处理超大数据集

现在,我们将创建一个自定义数据加载器,它利用 mmap 来高效地读取超大数据集。 假设我们的数据集是一个文本文件,其中每行代表一个数据样本。

数据格式示例:

1.0,2.0,3.0,4.0,5.0
6.0,7.0,8.0,9.0,10.0
11.0,12.0,13.0,14.0,15.0
...

自定义数据加载器代码:

import mmap
import numpy as np
import os

class LargeTextDataset:
    def __init__(self, filename, delimiter=",", dtype=np.float32):
        self.filename = filename
        self.delimiter = delimiter.encode() # 将分隔符转换为字节串
        self.dtype = dtype
        self.file_size = os.path.getsize(self.filename) # 获取文件大小
        self.mmap_obj = None
        self.line_offsets = []  # 存储每一行的起始位置

        self._build_index()

    def _build_index(self):
        """
        构建索引,记录每行的起始位置
        """
        with open(self.filename, "rb") as f: # 以二进制只读模式打开文件
            self.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) # 创建只读的 mmap 对象

            offset = 0
            self.line_offsets.append(offset) # 第一行的起始位置

            while offset < self.file_size:
                try:
                    next_newline = self.mmap_obj.find(b'n', offset) # 查找下一个换行符的位置
                    if next_newline == -1: # 如果找不到换行符,说明已经到达文件末尾
                        break
                    offset = next_newline + 1 # 更新偏移量到下一行的起始位置
                    self.line_offsets.append(offset) # 记录下一行的起始位置
                except ValueError as e:
                    print(f"Error building index at offset {offset}: {e}")
                    break

    def __len__(self):
        """
        返回数据集的样本数量
        """
        return len(self.line_offsets) - 1 # 减 1 因为最后一个偏移量是文件末尾

    def __getitem__(self, idx):
        """
        根据索引获取数据样本
        """
        if idx < 0 or idx >= len(self):
            raise IndexError("Index out of range")

        start = self.line_offsets[idx]
        end = self.line_offsets[idx + 1]

        line = self.mmap_obj[start:end].decode().strip() # 读取一行数据并解码
        values = np.fromstring(line, sep=self.delimiter.decode(), dtype=self.dtype) # 将字符串转换为 NumPy 数组

        return values

    def close(self):
        """
        关闭 mmap 对象
        """
        if self.mmap_obj:
            self.mmap_obj.close()

# Example Usage:
filename = "large_dataset.txt"
# Create a dummy large dataset
with open(filename, 'w') as f:
    for i in range(100000): # 100,000 lines
        f.write(','.join([str(j) for j in range(5)]) + 'n') # 5 values per line

dataset = LargeTextDataset(filename)

# Accessing data
print(f"Dataset size: {len(dataset)}")
print(f"First sample: {dataset[0]}")
print(f"1000th sample: {dataset[999]}")

dataset.close() # Important: Close the mmap object when finished

# Clean up the dummy file
os.remove(filename)

代码解释:

  1. __init__(self, filename, delimiter=",", dtype=np.float32):

    • 构造函数,接收文件名、分隔符和数据类型作为参数。
    • self.filename: 存储文件名。
    • self.delimiter: 存储分隔符(转换为字节串)。
    • self.dtype: 存储数据类型。
    • self.file_size: 存储文件大小。
    • self.mmap_obj: 存储 mmap 对象 (初始为 None)。
    • self.line_offsets: 存储每一行的起始位置 (一个列表)。
    • 调用 _build_index() 方法来构建索引。
  2. _build_index(self):

    • 构建索引,该索引记录了文件中每一行的起始位置。
    • 以二进制只读模式 ("rb") 打开文件。
    • 创建只读的 mmap 对象 (mmap.ACCESS_READ)。 只读模式可以防止意外修改文件内容。
    • 初始化 offset 为 0 (文件起始位置)。
    • 将第一行的起始位置 (0) 添加到 self.line_offsets 列表中。
    • 使用 while 循环遍历文件,直到到达文件末尾。
    • 在循环中,使用 self.mmap_obj.find(b'n', offset) 查找从当前 offset 开始的下一个换行符 (n) 的位置。
    • 如果找不到换行符 (返回 -1),则表示已到达文件末尾,循环结束。
    • 否则,将 offset 更新为下一个换行符的位置加 1 (即下一行的起始位置),并将新的 offset 添加到 self.line_offsets 列表中。
  3. __len__(self):

    • 返回数据集的样本数量,即 self.line_offsets 列表的长度减 1。 减 1 是因为最后一个偏移量代表文件末尾,而不是一个实际的样本。
  4. __getitem__(self, idx):

    • 根据索引 idx 获取数据样本。
    • 首先进行索引检查,确保 idx 在有效范围内。
    • self.line_offsets 列表中获取索引 idx 对应的起始位置 start 和索引 idx + 1 对应的结束位置 end
    • 使用 self.mmap_obj[start:end]mmap 对象中读取该行的数据,然后使用 .decode() 方法将其解码为字符串,并使用 .strip() 方法去除首尾的空白字符。
    • 使用 np.fromstring(line, sep=self.delimiter.decode(), dtype=self.dtype) 将字符串转换为 NumPy 数组。 sep 参数指定分隔符,需要先解码为字符串。
    • 返回 NumPy 数组。
  5. close(self):

    • 关闭 mmap 对象,释放资源。 这是一个很重要的步骤,应该在完成数据访问后调用。

关键点:

  • 索引构建: _build_index() 方法是关键。它预先扫描整个文件,并记录每一行的起始位置。这使得我们可以通过索引快速访问任何一行数据,而无需从头开始读取文件。
  • 按需加载: mmap 对象允许我们按需访问文件中的数据,而不是一次性将整个文件加载到内存中。
  • 只读模式: 为了防止意外修改文件内容,我们以只读模式 (mmap.ACCESS_READ) 创建 mmap 对象。
  • 资源释放: 记得在完成数据访问后调用 close() 方法关闭 mmap 对象,释放资源。

5. 性能优化技巧

虽然 mmap 已经是一种非常高效的技术,但我们仍然可以通过一些技巧来进一步优化性能:

  • 选择合适的数据类型: 使用合适的数据类型可以减少内存占用和提高计算速度。 例如,如果数据范围较小,可以使用 np.int8np.float16 代替 np.int64np.float64
  • 批量处理: 尽量一次性处理多个数据样本,而不是逐个处理。 例如,可以实现一个 __getitem__ 方法,它接收一个索引列表,并返回一个包含多个样本的 NumPy 数组。
  • 多线程/多进程: 利用多线程或多进程并行处理数据。 由于 mmap 对象可以在多个进程之间共享,因此可以使用多进程来加速数据预处理。
  • 避免不必要的解码: 尽量避免频繁地在字节串和字符串之间进行转换。 如果只需要比较或搜索字节串,可以直接操作字节串,而无需解码为字符串。

6. 与其他数据加载方法的比较

方法 优点 缺点 适用场景
传统的文件读取 简单易用 内存占用高,IO效率低 数据集较小,可以完全加载到内存中
mmap 内存占用低,IO效率高,支持随机访问 实现相对复杂 超大数据集,无法完全加载到内存中,需要随机访问
Dask 支持并行计算,可以处理大于内存的数据集 需要学习 Dask 的 API,可能存在额外的开销 复杂的数据处理流程,需要并行计算
PyArrow 高效的内存数据结构,支持零拷贝操作,可以与其他语言共享数据 需要安装 PyArrow 库,可能存在兼容性问题 需要与其他语言共享数据,或者需要高效的内存数据结构

7. 高效处理超大数据集的案例:图像数据

上面的例子是处理文本数据。现在我们来看一个处理图像数据的例子。 假设我们有一个包含大量图像的文件,每个图像都以固定的字节数存储。

import mmap
import numpy as np
import os
from PIL import Image

class LargeImageDataset:
    def __init__(self, filename, image_width, image_height, image_format="L"):
        """
        初始化 LargeImageDataset 类。

        Args:
            filename (str): 包含图像数据的文件名。
            image_width (int): 每个图像的宽度。
            image_height (int): 每个图像的高度。
            image_format (str): 图像格式,例如 "L" (灰度), "RGB" (彩色).
        """
        self.filename = filename
        self.image_width = image_width
        self.image_height = image_height
        self.image_format = image_format
        self.image_size = image_width * image_height  # 每个图像的字节数
        if image_format == "RGB":
            self.image_size *= 3
        self.file_size = os.path.getsize(self.filename)
        self.num_images = self.file_size // self.image_size # 计算图像数量
        self.mmap_obj = None

        # 打开文件并创建 mmap 对象
        with open(self.filename, "rb") as f:
            self.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

    def __len__(self):
        """
        返回数据集中的图像数量。
        """
        return self.num_images

    def __getitem__(self, idx):
        """
        根据索引获取图像。

        Args:
            idx (int): 图像的索引。

        Returns:
            PIL.Image.Image: 图像对象。
        """
        if idx < 0 or idx >= len(self):
            raise IndexError("Index out of range")

        start = idx * self.image_size
        end = start + self.image_size
        image_data = self.mmap_obj[start:end] # 从 mmap 对象中读取图像数据

        # 创建 PIL 图像对象
        image = Image.frombytes(self.image_format, (self.image_width, self.image_height), image_data)
        return image

    def close(self):
        """
        关闭 mmap 对象。
        """
        if self.mmap_obj:
            self.mmap_obj.close()

# Example Usage:
filename = "large_image_dataset.data"
image_width = 64
image_height = 64
image_format = "L" # 灰度图像

# 创建一个虚拟的图像数据集
num_images = 1000
image_size = image_width * image_height
with open(filename, "wb") as f:
    for _ in range(num_images):
        # 创建一个随机的灰度图像数据
        image_data = np.random.randint(0, 256, size=image_size, dtype=np.uint8).tobytes()
        f.write(image_data)

dataset = LargeImageDataset(filename, image_width, image_height, image_format)

# 访问图像
print(f"Dataset size: {len(dataset)}")
image = dataset[0] # 获取第一张图像
image.show() # 显示图像

dataset.close()
os.remove(filename)

代码解释:

  1. __init__(self, filename, image_width, image_height, image_format="L"):

    • 初始化函数,接收文件名、图像宽度、图像高度和图像格式作为参数。
    • 计算每个图像的字节数 image_size,并根据图像格式进行调整。
    • 计算数据集中图像的数量 num_images
    • 打开文件并创建只读的 mmap 对象。
  2. __len__(self):

    • 返回数据集中图像的数量。
  3. __getitem__(self, idx):

    • 根据索引 idx 获取图像。
    • 计算图像数据在文件中的起始位置 start 和结束位置 end
    • mmap 对象中读取图像数据。
    • 使用 PIL.Image.frombytes() 函数从字节数据创建 PIL 图像对象。
    • 返回图像对象。
  4. close(self):

    • 关闭 mmap 对象。

在这个例子中,我们假设图像数据以原始字节的形式存储在文件中,并且每个图像的大小相同。 mmap 允许我们高效地访问这些图像数据,而无需将整个文件加载到内存中。 我们可以使用 PIL (Python Imaging Library) 或其他图像处理库来处理这些图像数据。

最后说几句

总而言之,mmap 是一种处理超大数据集的强大技术。 通过结合 mmap 和自定义数据加载器,我们可以高效地访问和处理无法完全加载到内存中的数据。 在实际应用中,我们需要根据具体的数据格式和处理需求,选择合适的优化策略。

更多IT精英技术系列讲座,到智猿学院

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注