JS `Federated Learning` `Algorithms` (`Federated Averaging`) 在浏览器中的实现

各位前端的靓仔们,晚上好!我是你们的野生前端架构师,今儿咱不聊 Vue,不唠 React,咱们来点刺激的——浏览器里的 Federated Learning!

啥?你问我 Federated Learning 是啥?简单来说,就是“数据不出门,模型来回跑”。想象一下,你有一堆数据,我有一堆数据,咱俩都不想给对方看,但又想一起训练出一个牛逼的模型。 Federated Learning 就是干这个的!

今天,咱们就来聊聊如何在浏览器里实现 Federated Averaging,也就是联邦学习里最基础也最常用的算法。准备好了吗?发车!

第一站:理论先行,知其所以然

在咱们撸代码之前,先简单回顾一下 Federated Averaging (FedAvg) 的核心思想:

  1. 数据本地化: 每个客户端(在这里就是每个浏览器)都拥有自己的数据集,数据不会上传到中央服务器。
  2. 本地训练: 每个客户端使用自己的数据,在本地训练一个模型。
  3. 参数聚合: 客户端将训练好的模型参数(比如神经网络的权重)上传到中央服务器。
  4. 全局平均: 中央服务器对所有客户端上传的参数进行平均,得到一个全局模型。
  5. 模型分发: 中央服务器将全局模型分发给所有客户端,客户端用这个模型初始化自己的模型,然后重复步骤 2-5。

这个过程就像是大家一起做一道菜,每个人都在自己家做一部分,然后把做好的部分送到中央厨房,中央厨房把大家做的东西混合一下,再分发给大家,让大家继续改进。

第二站:环境搭建,准备开工

要在浏览器里搞 Federated Learning,我们需要一些工具:

  • TensorFlow.js: Google 出品的 JavaScript 机器学习库,让咱们可以在浏览器里玩转神经网络。
  • 一个简单的 Web 服务器: 用于分发模型和协调客户端。可以用 Node.js 搞一个,也可以直接用 Python 的 http.server 模块。
  • 浏览器: 废话!没有浏览器怎么跑 JavaScript?

先装好 TensorFlow.js:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

第三站:模型设计,打好地基

咱们先来设计一个简单的模型,就用一个线性回归模型,简单粗暴,方便理解:

async function createModel() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
  return model;
}

这段代码使用 TensorFlow.js 创建了一个包含一个 Dense 层的线性回归模型。 units: 1 表示输出维度为 1, inputShape: [1] 表示输入维度为 1。 loss: 'meanSquaredError' 指定了损失函数为均方误差, optimizer: 'sgd' 指定了优化器为随机梯度下降。

第四站:客户端实现,各司其职

接下来,咱们来实现客户端的代码,也就是在浏览器里运行的代码:

class FederatedClient {
  constructor(serverUrl, clientId) {
    this.serverUrl = serverUrl;
    this.clientId = clientId;
    this.model = null;
    this.xs = null;
    this.ys = null;
  }

  async loadData() {
    // 模拟加载本地数据
    this.xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
    this.ys = tf.tensor2d([[2], [4], [6], [8]], [4, 1]);
  }

  async initModel() {
    this.model = await createModel();
  }

  async trainModel() {
    await this.model.fit(this.xs, this.ys, {epochs: 10});
    console.log(`Client ${this.clientId} training complete.`);
  }

  async sendModel() {
    const weights = this.model.getWeights();
    const weightsData = await Promise.all(weights.map(w => w.data()));
    const payload = {
      clientId: this.clientId,
      weights: weightsData,
      shapes: weights.map(w => w.shape)
    };

    await fetch(`${this.serverUrl}/upload`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json'
      },
      body: JSON.stringify(payload)
    });
    console.log(`Client ${this.clientId} sent model to server.`);
  }

  async getGlobalModel() {
    const response = await fetch(`${this.serverUrl}/model`);
    const data = await response.json();
    const weights = data.weights.map((w, i) => tf.tensor(w, data.shapes[i]));
    this.model.setWeights(weights);
    console.log(`Client ${this.clientId} received global model.`);
  }

  async run() {
    await this.loadData();
    await this.initModel();
    await this.getGlobalModel(); // Get initial global model
    await this.trainModel();
    await this.sendModel();
  }
}

这段代码定义了一个 FederatedClient 类,它负责:

  • 加载本地数据: loadData() 方法模拟加载本地数据,这里用的是简单的线性数据。
  • 初始化模型: initModel() 方法调用 createModel() 函数创建一个模型。
  • 训练模型: trainModel() 方法使用本地数据训练模型。
  • 发送模型: sendModel() 方法将模型的权重发送到中央服务器。注意,这里需要将 TensorFlow.js 的 Tensor 对象转换为 JavaScript 的数组,才能通过 JSON 传输。
  • 获取全局模型: getGlobalModel() 方法从中央服务器获取全局模型,并用它来更新本地模型。

第五站:服务器实现,运筹帷幄

接下来,咱们来实现中央服务器的代码,这里用 Node.js 来实现:

const express = require('express');
const bodyParser = require('body-parser');
const tf = require('@tensorflow/tfjs-node');

const app = express();
const port = 3000;

app.use(bodyParser.json({limit: '50mb'})); // Allow larger payloads

let globalModel = null;
let clientWeights = {};
let clientShapes = {};
let clientCounts = {};

async function createInitialModel() {
    const model = tf.sequential();
    model.add(tf.layers.dense({units: 1, inputShape: [1]}));
    model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
    // Initialize weights with random values
    const xs = tf.tensor2d([[0]], [1, 1]);
    const ys = tf.tensor2d([[0]], [1, 1]);
    await model.fit(xs, ys, {epochs: 1});
    return model;
}

async function averageWeights(weights, shapes, counts) {
    const averagedWeights = [];
    for (let i = 0; i < weights[0].length; i++) { // Iterate through each layer's weights
        const layerWeights = weights.map(clientWeight => clientWeight[i]);
        const layerShapes = shapes.map(clientShape => clientShape[i]);

        // Ensure all clients have the same shape for the layer
        if (!layerShapes.every(shape => JSON.stringify(shape) === JSON.stringify(layerShapes[0]))) {
            console.error("Shapes are inconsistent across clients for layer:", i);
            return null; // Handle the inconsistency appropriately
        }

        let sum = tf.zeros(layerShapes[0]);
        for (let j = 0; j < weights.length; j++) {
            const tensor = tf.tensor(layerWeights[j], layerShapes[j]);
            const scaledTensor = tensor.mul(counts[j]); // Weight by number of samples
            sum = sum.add(scaledTensor);
            tensor.dispose(); // Dispose of the tensor to prevent memory leaks
            scaledTensor.dispose();
        }
        const totalSamples = counts.reduce((a, b) => a + b, 0);
        const average = sum.div(totalSamples);
        averagedWeights.push(await average.data());
        sum.dispose();
        average.dispose();
    }
    return averagedWeights;
}

app.post('/upload', async (req, res) => {
  const { clientId, weights, shapes } = req.body;
  clientWeights[clientId] = weights;
  clientShapes[clientId] = shapes;
  clientCounts[clientId] = 4; // Assume each client has 4 samples

  console.log(`Received model from client ${clientId}`);

  // Check if we have weights from all clients (e.g., 2 clients)
  if (Object.keys(clientWeights).length === 2) {
    const weightsArray = Object.values(clientWeights);
    const shapesArray = Object.values(clientShapes);
    const countsArray = Object.values(clientCounts);

    const averagedWeights = await averageWeights(weightsArray, shapesArray, countsArray);

    if (averagedWeights) {
      const newWeights = averagedWeights.map((w, i) => tf.tensor(w, shapesArray[0][i]));
      globalModel.setWeights(newWeights);

      // Clear client weights after averaging
      clientWeights = {};
      clientShapes = {};
      clientCounts = {};
      console.log('Global model updated.');
    } else {
      console.error('Failed to average weights due to inconsistent shapes.');
    }
  }
  res.sendStatus(200);
});

app.get('/model', async (req, res) => {
  if (!globalModel) {
    globalModel = await createInitialModel();
  }
  const weights = globalModel.getWeights();
  const weightsData = await Promise.all(weights.map(w => w.data()));
  const shapes = weights.map(w => w.shape);

  res.json({ weights: weightsData, shapes: shapes });
});

async function startServer() {
  globalModel = await createInitialModel(); // Initialize the global model
  app.listen(port, () => {
    console.log(`Server listening at http://localhost:${port}`);
  });
}

startServer();

这段代码使用 Express.js 创建了一个 Web 服务器,它负责:

  • 接收模型: /upload 路由接收客户端上传的模型权重。
  • 聚合参数: averageWeights() 函数对所有客户端上传的权重进行平均。这里需要将 JavaScript 的数组转换为 TensorFlow.js 的 Tensor 对象,才能进行计算。
  • 分发模型: /model 路由将全局模型分发给客户端。

第六站:整合测试,见证奇迹

现在,咱们把客户端和服务端整合起来,进行测试:

  1. 启动 Node.js 服务器: node server.js
  2. 在两个浏览器窗口中分别打开一个 HTML 页面,页面中包含以下代码:
<!DOCTYPE html>
<html>
<head>
  <title>Federated Learning Client</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
</head>
<body>
  <h1>Federated Learning Client</h1>
  <script>
    const serverUrl = 'http://localhost:3000';
    const clientId = Math.random().toString(36).substring(7); // Generate a random client ID

    const client = new FederatedClient(serverUrl, clientId);

    async function runClient() {
      await client.run();
    }

    runClient();
  </script>
  <script>
    class FederatedClient {
      constructor(serverUrl, clientId) {
        this.serverUrl = serverUrl;
        this.clientId = clientId;
        this.model = null;
        this.xs = null;
        this.ys = null;
      }

      async loadData() {
        // 模拟加载本地数据
        this.xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
        this.ys = tf.tensor2d([[2], [4], [6], [8]], [4, 1]);
      }

      async initModel() {
        this.model = await createModel();
      }

      async trainModel() {
        await this.model.fit(this.xs, this.ys, {epochs: 10});
        console.log(`Client ${this.clientId} training complete.`);
      }

      async sendModel() {
        const weights = this.model.getWeights();
        const weightsData = await Promise.all(weights.map(w => w.data()));
        const payload = {
          clientId: this.clientId,
          weights: weightsData,
          shapes: weights.map(w => w.shape)
        };

        await fetch(`${this.serverUrl}/upload`, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json'
          },
          body: JSON.stringify(payload)
        });
        console.log(`Client ${this.clientId} sent model to server.`);
      }

      async getGlobalModel() {
        const response = await fetch(`${this.serverUrl}/model`);
        const data = await response.json();
        const weights = data.weights.map((w, i) => tf.tensor(w, data.shapes[i]));
        this.model.setWeights(weights);
        console.log(`Client ${this.clientId} received global model.`);
      }

      async run() {
        await this.loadData();
        await this.initModel();
        await this.getGlobalModel(); // Get initial global model
        await this.trainModel();
        await this.sendModel();
      }
    }

    async function createModel() {
      const model = tf.sequential();
      model.add(tf.layers.dense({units: 1, inputShape: [1]}));
      model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
      return model;
    }
  </script>
</body>
</html>
  1. 打开浏览器的开发者工具,观察控制台的输出。你应该能看到客户端训练模型、发送模型、接收全局模型等信息。
  2. 重复运行客户端代码多次,观察模型的训练效果。你会发现,随着训练的进行,全局模型的性能会越来越好。

第七站:优化升级,更上一层楼

虽然咱们已经成功实现了 Federated Averaging,但这只是一个简单的示例。在实际应用中,还需要考虑很多问题:

  • 数据异构性: 不同的客户端可能拥有不同分布的数据。
  • 客户端选择: 不是所有客户端都适合参与训练,需要选择合适的客户端。
  • 安全隐私: 需要采取措施保护客户端的数据隐私。
  • 通信效率: 需要优化客户端和服务器之间的通信效率。

这里列一些可以考虑的优化方向,大家可以根据自己的需求进行选择:

优化方向 描述
数据增强 通过数据增强技术来缓解数据异构性。
差分隐私 在模型训练过程中加入噪声,保护客户端的数据隐私。
模型压缩 使用模型压缩技术来减少模型的大小,提高通信效率。
异步联邦学习 允许客户端异步地训练和上传模型,提高训练效率。
客户端采样 选择一部分客户端参与每一轮的训练,可以减少计算负担,并且可以处理客户端掉线的情况。

总结:学无止境,未来可期

恭喜各位!咱们已经成功地在浏览器里实现了 Federated Averaging。 虽然这只是一个起点,但它为咱们打开了一扇通往 Federated Learning 世界的大门。 希望大家能够继续探索,在实际项目中应用 Federated Learning,为构建更安全、更智能的未来贡献力量!

记住,前端er 也是可以搞机器学习的! 冲鸭!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注