JAVA OpenAI 多模型混用难?统一 ModelResolver 架构设计

JAVA OpenAI 多模型混用难?统一 ModelResolver 架构设计

大家好,今天我们来聊聊如何在 Java 项目中更优雅地使用 OpenAI 的多个模型。随着 OpenAI 提供的模型种类越来越多,例如 gpt-3.5-turbogpt-4text-embedding-ada-002 等,我们常常需要在同一个项目中根据不同的任务选用不同的模型。直接在代码中硬编码模型名称,会导致代码冗余、难以维护,且缺乏灵活性。

因此,我们需要一种统一的模型解析方案,能够根据特定条件自动选择合适的模型,并提供统一的接口访问。这就是我们今天要讨论的 ModelResolver 架构。

痛点分析:直接使用的弊端

在深入 ModelResolver 架构之前,我们先来看看直接使用 OpenAI 模型可能遇到的问题。

假设我们有一个文本摘要服务,需要根据文本长度选择不同的模型。较短的文本可以使用 gpt-3.5-turbo,较长的文本则使用 gpt-4 以获得更好的效果。

public class TextSummarizer {

    private final OpenAI openAI;

    public TextSummarizer(OpenAI openAI) {
        this.openAI = openAI;
    }

    public String summarize(String text) {
        String modelName;
        if (text.length() <= 500) {
            modelName = "gpt-3.5-turbo";
        } else {
            modelName = "gpt-4";
        }

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }
}

这段代码虽然简单,但存在以下问题:

  • 硬编码模型名称: 模型名称直接写死在代码中,修改模型需要修改代码并重新部署。
  • 逻辑分散: 选择模型的逻辑分散在各个业务方法中,如果多个服务都需要根据文本长度选择模型,则需要在多个地方重复编写相同的逻辑。
  • 缺乏扩展性: 如果需要增加新的模型选择策略(例如根据用户级别选择模型),则需要修改大量的代码。
  • 难以测试: 由于模型选择逻辑与业务逻辑耦合在一起,难以进行单元测试。

ModelResolver 架构设计

ModelResolver 架构的核心思想是将模型选择逻辑从业务代码中解耦出来,通过定义一组 ModelResolver 接口,实现不同的模型选择策略。业务代码只需要调用 ModelResolver 接口,即可获取合适的模型名称。

1. ModelResolver 接口

首先,我们定义一个 ModelResolver 接口,该接口负责根据输入参数解析出模型名称。

public interface ModelResolver {
    /**
     * 根据输入参数解析模型名称.
     *
     * @param input 输入参数
     * @return 模型名称
     */
    String resolve(Object input);
}

2. 多个 ModelResolver 实现

我们可以实现多个 ModelResolver 接口,分别对应不同的模型选择策略。

  • LengthBasedModelResolver: 根据文本长度选择模型。
public class LengthBasedModelResolver implements ModelResolver {

    private final String shortTextModel;
    private final String longTextModel;
    private final int maxLength;

    public LengthBasedModelResolver(String shortTextModel, String longTextModel, int maxLength) {
        this.shortTextModel = shortTextModel;
        this.longTextModel = longTextModel;
        this.maxLength = maxLength;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String text = (String) input;
        return text.length() <= maxLength ? shortTextModel : longTextModel;
    }
}
  • UserLevelModelResolver: 根据用户级别选择模型。
public class UserLevelModelResolver implements ModelResolver {

    private final Map<String, String> modelMap;

    public UserLevelModelResolver(Map<String, String> modelMap) {
        this.modelMap = modelMap;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String userLevel = (String) input;
        return modelMap.getOrDefault(userLevel, "default-model");
    }
}
  • CompositeModelResolver: 组合多个 ModelResolver,按照优先级依次尝试解析模型名称。
public class CompositeModelResolver implements ModelResolver {

    private final List<ModelResolver> resolvers;

    public CompositeModelResolver(List<ModelResolver> resolvers) {
        this.resolvers = resolvers;
    }

    @Override
    public String resolve(Object input) {
        for (ModelResolver resolver : resolvers) {
            String modelName = resolver.resolve(input);
            if (modelName != null && !modelName.isEmpty()) {
                return modelName;
            }
        }
        return null; // Or throw an exception if no model can be resolved
    }
}

3. 使用 ModelResolver

在业务代码中,我们只需要注入 ModelResolver 接口,即可获取模型名称。

public class TextSummarizer {

    private final OpenAI openAI;
    private final ModelResolver modelResolver;

    public TextSummarizer(OpenAI openAI, ModelResolver modelResolver) {
        this.openAI = openAI;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text) {
        String modelName = modelResolver.resolve(text);

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }
}

4. 配置 ModelResolver

我们可以通过配置文件(例如 application.propertiesapplication.yml)来配置 ModelResolver

例如,使用 Spring Boot 配置文件:

model.resolver.type: composite
model.resolver.length.short-text-model: gpt-3.5-turbo
model.resolver.length.long-text-model: gpt-4
model.resolver.length.max-length: 500
model.resolver.user-level.model-map.basic: gpt-3.5-turbo
model.resolver.user-level.model-map.premium: gpt-4
model.resolver.user-level.default-model: gpt-3.5-turbo

然后,我们可以编写一个配置类来创建 ModelResolver 实例。

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Configuration
public class ModelResolverConfig {

    @Value("${model.resolver.type}")
    private String resolverType;

    @Value("${model.resolver.length.short-text-model:gpt-3.5-turbo}")
    private String shortTextModel;

    @Value("${model.resolver.length.long-text-model:gpt-4}")
    private String longTextModel;

    @Value("${model.resolver.length.max-length:500}")
    private int maxLength;

    @Value("#{${model.resolver.user-level.model-map:{basic:gpt-3.5-turbo,premium:gpt-4}}}")
    private Map<String, String> userLevelModelMap;

    @Value("${model.resolver.user-level.default-model:gpt-3.5-turbo}")
    private String defaultUserLevelModel;

    @Bean
    public ModelResolver modelResolver() {
        if ("composite".equalsIgnoreCase(resolverType)) {
            return compositeModelResolver();
        } else if ("length".equalsIgnoreCase(resolverType)) {
            return lengthBasedModelResolver();
        } else if ("user-level".equalsIgnoreCase(resolverType)) {
            return userLevelModelResolver();
        }
        throw new IllegalArgumentException("Invalid resolver type: " + resolverType);
    }

    @Bean
    public LengthBasedModelResolver lengthBasedModelResolver() {
        return new LengthBasedModelResolver(shortTextModel, longTextModel, maxLength);
    }

    @Bean
    public UserLevelModelResolver userLevelModelResolver() {
        return new UserLevelModelResolver(userLevelModelMap);
    }

    @Bean
    public CompositeModelResolver compositeModelResolver() {
        List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver(), userLevelModelResolver());
        return new CompositeModelResolver(resolvers);
    }
}

5. Spring Boot 集成示例

import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.service.OpenAI;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class TextSummarizer {

    private final OpenAI openAI;
    private final ModelResolver modelResolver;

    @Autowired
    public TextSummarizer(OpenAI openAI, ModelResolver modelResolver) {
        this.openAI = openAI;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text, String userLevel) {
        // 可以传入多个参数,例如文本内容和用户级别
        // 在 CompositeModelResolver 中,可以先根据文本长度选择模型,再根据用户级别进行调整
        String modelName = modelResolver.resolve(text);

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }

    public String summarize(String text) {
         String modelName = modelResolver.resolve(text);

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }
}

6. 扩展性考虑

ModelResolver 架构具有良好的扩展性。如果需要增加新的模型选择策略,只需要实现 ModelResolver 接口,并将其注册到 CompositeModelResolver 中即可。

例如,我们可以实现一个 ABTestModelResolver,根据 A/B 测试结果选择模型。

public class ABTestModelResolver implements ModelResolver {

    private final String modelA;
    private final String modelB;
    private final double probabilityA;

    public ABTestModelResolver(String modelA, String modelB, double probabilityA) {
        this.modelA = modelA;
        this.modelB = modelB;
        this.probabilityA = probabilityA;
    }

    @Override
    public String resolve(Object input) {
        double random = Math.random();
        return random < probabilityA ? modelA : modelB;
    }
}

然后,将 ABTestModelResolver 注册到 CompositeModelResolver 中。

@Bean
public ABTestModelResolver abTestModelResolver() {
    return new ABTestModelResolver("gpt-3.5-turbo", "gpt-4", 0.5);
}

@Bean
public CompositeModelResolver compositeModelResolver(LengthBasedModelResolver lengthBasedModelResolver, UserLevelModelResolver userLevelModelResolver, ABTestModelResolver abTestModelResolver) {
    List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver, userLevelModelResolver, abTestModelResolver);
    return new CompositeModelResolver(resolvers);
}

7. 单元测试

ModelResolver 架构使得单元测试更加容易。我们可以针对每个 ModelResolver 接口编写单元测试,验证其模型选择逻辑的正确性。

例如,针对 LengthBasedModelResolver 的单元测试:

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

public class LengthBasedModelResolverTest {

    @Test
    public void testResolveShortText() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        String modelName = resolver.resolve("This is a short text.");
        assertEquals("gpt-3.5-turbo", modelName);
    }

    @Test
    public void testResolveLongText() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        String modelName = resolver.resolve("This is a long text that exceeds the maximum length.");
        assertEquals("gpt-4", modelName);
    }

    @Test
    public void testInvalidInput() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        assertThrows(IllegalArgumentException.class, () -> resolver.resolve(123));
    }
}

代码示例:完整的可运行示例

// 依赖添加:
// implementation("com.theokanning.openai:openai-java:0.17.0")
// implementation("org.springframework.boot:spring-boot-starter-web:3.0.2")
// maven
//
//        <dependency>
//            <groupId>com.theokanning.openai</groupId>
//            <artifactId>openai-java</artifactId>
//            <version>0.17.0</version>
//        </dependency>
//        <dependency>
//            <groupId>org.springframework.boot</groupId>
//            <artifactId>spring-boot-starter-web</artifactId>
//            <version>3.0.2</version>
//        </dependency>

import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.service.OpenAI;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

interface ModelResolver {
    String resolve(Object input);
}

class LengthBasedModelResolver implements ModelResolver {

    private final String shortTextModel;
    private final String longTextModel;
    private final int maxLength;

    public LengthBasedModelResolver(String shortTextModel, String longTextModel, int maxLength) {
        this.shortTextModel = shortTextModel;
        this.longTextModel = longTextModel;
        this.maxLength = maxLength;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String text = (String) input;
        return text.length() <= maxLength ? shortTextModel : longTextModel;
    }
}

class UserLevelModelResolver implements ModelResolver {

    private final Map<String, String> modelMap;

    public UserLevelModelResolver(Map<String, String> modelMap) {
        this.modelMap = modelMap;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String userLevel = (String) input;
        return modelMap.getOrDefault(userLevel, "default-model");
    }
}

class CompositeModelResolver implements ModelResolver {

    private final List<ModelResolver> resolvers;

    public CompositeModelResolver(List<ModelResolver> resolvers) {
        this.resolvers = resolvers;
    }

    @Override
    public String resolve(Object input) {
        for (ModelResolver resolver : resolvers) {
            String modelName = resolver.resolve(input);
            if (modelName != null && !modelName.isEmpty()) {
                return modelName;
            }
        }
        return null; // Or throw an exception if no model can be resolved
    }
}

@Service
class TextSummarizer {

    private final OpenAI openAI;
    private final ModelResolver modelResolver;

    public TextSummarizer(OpenAI openAI, ModelResolver modelResolver) {
        this.openAI = openAI;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text, String userLevel) {

        String modelName = modelResolver.resolve(text);

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }

    public String summarize(String text) {
        String modelName = modelResolver.resolve(text);

        CompletionRequest completionRequest = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .maxTokens(100)
                .build();

        return openAI.createCompletion(completionRequest).getChoices().get(0).getText();
    }
}

@Configuration
class ModelResolverConfig {

    @Value("${model.resolver.type}")
    private String resolverType;

    @Value("${model.resolver.length.short-text-model:gpt-3.5-turbo}")
    private String shortTextModel;

    @Value("${model.resolver.length.long-text-model:gpt-4}")
    private String longTextModel;

    @Value("${model.resolver.length.max-length:500}")
    private int maxLength;

    @Value("#{${model.resolver.user-level.model-map:{basic:gpt-3.5-turbo,premium:gpt-4}}}")
    private Map<String, String> userLevelModelMap;

    @Value("${model.resolver.user-level.default-model:gpt-3.5-turbo}")
    private String defaultUserLevelModel;

    @Bean
    public ModelResolver modelResolver() {
        if ("composite".equalsIgnoreCase(resolverType)) {
            return compositeModelResolver();
        } else if ("length".equalsIgnoreCase(resolverType)) {
            return lengthBasedModelResolver();
        } else if ("user-level".equalsIgnoreCase(resolverType)) {
            return userLevelModelResolver();
        }
        throw new IllegalArgumentException("Invalid resolver type: " + resolverType);
    }

    @Bean
    public LengthBasedModelResolver lengthBasedModelResolver() {
        return new LengthBasedModelResolver(shortTextModel, longTextModel, maxLength);
    }

    @Bean
    public UserLevelModelResolver userLevelModelResolver() {
        return new UserLevelModelResolver(userLevelModelMap);
    }

    @Bean
    public CompositeModelResolver compositeModelResolver() {
        List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver(), userLevelModelResolver());
        return new CompositeModelResolver(resolvers);
    }
}

@SpringBootApplication
@RestController
class DemoApplication {

    @Value("${openai.api.key}")
    private String openaiApiKey;

    private final TextSummarizer textSummarizer;

    public DemoApplication(TextSummarizer textSummarizer) {
        this.textSummarizer = textSummarizer;
    }

    @Bean
    public OpenAI openAI() {
        return new OpenAI(openaiApiKey);
    }

    @GetMapping("/summarize")
    public String summarize(@RequestParam String text) {
        return textSummarizer.summarize(text);
    }

    public static void main(String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}

application.properties 中配置:

openai.api.key=YOUR_OPENAI_API_KEY
model.resolver.type=composite
model.resolver.length.short-text-model=gpt-3.5-turbo
model.resolver.length.long-text-model=gpt-4
model.resolver.length.max-length=500
model.resolver.user-level.model-map.basic=gpt-3.5-turbo
model.resolver.user-level.model-map.premium=gpt-4
model.resolver.user-level.default-model=gpt-3.5-turbo

架构优势:解耦、灵活、易维护

通过 ModelResolver 架构,我们实现了模型选择逻辑与业务代码的解耦,提高了代码的灵活性、可维护性和可测试性。

  • 解耦: 模型选择逻辑不再分散在各个业务方法中,而是集中在 ModelResolver 接口的实现类中。
  • 灵活: 可以根据不同的需求实现不同的 ModelResolver 接口,例如根据文本长度、用户级别、A/B 测试结果等选择模型。
  • 易维护: 修改模型选择策略只需要修改 ModelResolver 接口的实现类或配置文件,而不需要修改大量的业务代码。
  • 易测试: 可以针对每个 ModelResolver 接口编写单元测试,验证其模型选择逻辑的正确性。

局限性:复杂性增加

虽然 ModelResolver 架构带来了诸多优势,但也增加了一些复杂性。需要设计和维护 ModelResolver 接口及其实现类,并配置 CompositeModelResolver。但是,这种复杂性是值得的,因为它带来了更好的代码结构和可维护性。

未来方向:更智能的模型选择

ModelResolver 架构为我们提供了一个统一的模型解析方案,但模型选择策略仍然需要人工定义。未来,我们可以探索更智能的模型选择方法,例如使用机器学习模型自动选择合适的模型。

例如,我们可以训练一个模型,根据输入文本的特征(例如主题、情感、复杂度等)来选择最合适的模型。

模型选择逻辑的抽象与复用

通过 ModelResolver 架构,我们将模型选择逻辑从业务代码中抽离出来,实现了模型选择逻辑的抽象与复用。这种架构不仅适用于 OpenAI 模型,也适用于其他需要根据不同条件选择不同服务的场景。它提供了一种清晰、可扩展的方式来管理和维护多个模型,降低了代码的复杂性,并提高了系统的灵活性。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注