Java + OpenAI: Struggling to Mix Multiple Models? Designing a Unified ModelResolver Architecture
Hi everyone. Today let's look at how to use multiple OpenAI models more elegantly in a Java project. As OpenAI offers more and more models, such as gpt-3.5-turbo, gpt-4, and text-embedding-ada-002, we often need to pick different models for different tasks within the same project. Hard-coding model names directly in the code leads to duplication, poor maintainability, and a lack of flexibility.
We therefore need a unified model-resolution scheme that can automatically choose an appropriate model based on specific conditions and expose it behind a single interface. That is the ModelResolver architecture we will discuss today.
Pain Points: The Drawbacks of Direct Usage
Before diving into the ModelResolver architecture, let's look at the problems we run into when calling OpenAI models directly.
Suppose we have a text summarization service that should pick a model based on text length: shorter texts can use gpt-3.5-turbo, while longer texts use gpt-4 for better results.
public class TextSummarizer {
    private final OpenAiService openAiService;

    public TextSummarizer(OpenAiService openAiService) {
        this.openAiService = openAiService;
    }

    public String summarize(String text) {
        // Model names are hard-coded into the business logic
        String modelName;
        if (text.length() <= 500) {
            modelName = "gpt-3.5-turbo";
        } else {
            modelName = "gpt-4";
        }
        // gpt-3.5-turbo / gpt-4 are chat models, so call the chat completions endpoint
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelName)
                .messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), text)))
                .maxTokens(100)
                .build();
        return openAiService.createChatCompletion(request).getChoices().get(0).getMessage().getContent();
    }
}
This code is simple enough, but it has the following problems:
- Hard-coded model names: the model names are baked into the code, so switching models means changing code and redeploying.
- Scattered logic: the selection logic lives inside individual business methods; if several services need to pick a model by text length, the same logic gets duplicated in multiple places.
- Poor extensibility: adding a new selection strategy (for example, choosing a model by user tier) requires touching a lot of code.
- Hard to test: because model selection is tangled up with business logic, it is difficult to unit test.
ModelResolver Architecture Design
The core idea of the ModelResolver architecture is to decouple model selection from business code: we define a ModelResolver interface and implement different selection strategies behind it. Business code only calls the ModelResolver interface to obtain an appropriate model name.
1. The ModelResolver Interface
First, we define a ModelResolver interface that resolves a model name from an input parameter.
public interface ModelResolver {
    /**
     * Resolves a model name from the given input.
     *
     * @param input the input parameter
     * @return the model name
     */
    String resolve(Object input);
}
2. Multiple ModelResolver Implementations
We can provide several ModelResolver implementations, each corresponding to a different selection strategy.
- LengthBasedModelResolver: selects a model based on text length.
public class LengthBasedModelResolver implements ModelResolver {
    private final String shortTextModel;
    private final String longTextModel;
    private final int maxLength;

    public LengthBasedModelResolver(String shortTextModel, String longTextModel, int maxLength) {
        this.shortTextModel = shortTextModel;
        this.longTextModel = longTextModel;
        this.maxLength = maxLength;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String text = (String) input;
        return text.length() <= maxLength ? shortTextModel : longTextModel;
    }
}
- UserLevelModelResolver: selects a model based on the user's tier, falling back to a configurable default.
public class UserLevelModelResolver implements ModelResolver {
    private final Map<String, String> modelMap;
    private final String defaultModel;

    public UserLevelModelResolver(Map<String, String> modelMap, String defaultModel) {
        this.modelMap = modelMap;
        this.defaultModel = defaultModel;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String userLevel = (String) input;
        return modelMap.getOrDefault(userLevel, defaultModel);
    }
}
- CompositeModelResolver: combines several ModelResolvers and tries them in priority order until one of them resolves a model name. A standalone usage sketch follows the class below.
public class CompositeModelResolver implements ModelResolver {
    private final List<ModelResolver> resolvers;

    public CompositeModelResolver(List<ModelResolver> resolvers) {
        this.resolvers = resolvers;
    }

    @Override
    public String resolve(Object input) {
        for (ModelResolver resolver : resolvers) {
            String modelName = resolver.resolve(input);
            if (modelName != null && !modelName.isEmpty()) {
                return modelName;
            }
        }
        return null; // Or throw an exception if no model can be resolved
    }
}
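For clarity, here is a minimal standalone sketch (plain Java, no Spring) showing how the composite tries its resolvers in order. The model names and the 500-character threshold are illustrative values, not prescriptions:
import java.util.List;
import java.util.Map;

public class CompositeResolverDemo {
    public static void main(String[] args) {
        // Resolvers are consulted in list order; the first non-empty result wins.
        ModelResolver lengthResolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        ModelResolver userLevelResolver = new UserLevelModelResolver(Map.of("premium", "gpt-4"), "gpt-3.5-turbo");
        ModelResolver composite = new CompositeModelResolver(List.of(lengthResolver, userLevelResolver));

        // The length-based resolver answers first, so the user-level resolver is never reached here.
        System.out.println(composite.resolve("A short piece of text")); // prints gpt-3.5-turbo
    }
}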
3. Using the ModelResolver
In business code we simply inject a ModelResolver and ask it for the model name.
public class TextSummarizer {
    private final OpenAiService openAiService;
    private final ModelResolver modelResolver;

    public TextSummarizer(OpenAiService openAiService, ModelResolver modelResolver) {
        this.openAiService = openAiService;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text) {
        // The resolver decides which model to use; the business code no longer cares how
        String modelName = modelResolver.resolve(text);
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelName)
                .messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), text)))
                .maxTokens(100)
                .build();
        return openAiService.createChatCompletion(request).getChoices().get(0).getMessage().getContent();
    }
}
4. Configuring the ModelResolver
We can configure the ModelResolver through a configuration file (application.properties or application.yml).
For example, with a Spring Boot properties file:
model.resolver.type=composite
model.resolver.length.short-text-model=gpt-3.5-turbo
model.resolver.length.long-text-model=gpt-4
model.resolver.length.max-length=500
model.resolver.user-level.model-map={basic:'gpt-3.5-turbo',premium:'gpt-4'}
model.resolver.user-level.default-model=gpt-3.5-turbo
Then we write a configuration class that creates the ModelResolver beans.
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

@Configuration
public class ModelResolverConfig {

    @Value("${model.resolver.type}")
    private String resolverType;

    @Value("${model.resolver.length.short-text-model:gpt-3.5-turbo}")
    private String shortTextModel;

    @Value("${model.resolver.length.long-text-model:gpt-4}")
    private String longTextModel;

    @Value("${model.resolver.length.max-length:500}")
    private int maxLength;

    // The property value must be a SpEL map literal, e.g. {basic:'gpt-3.5-turbo',premium:'gpt-4'}
    @Value("#{${model.resolver.user-level.model-map:{basic:'gpt-3.5-turbo',premium:'gpt-4'}}}")
    private Map<String, String> userLevelModelMap;

    @Value("${model.resolver.user-level.default-model:gpt-3.5-turbo}")
    private String defaultUserLevelModel;

    // @Primary makes this the ModelResolver injected into business beans,
    // even though the concrete resolvers below are also ModelResolver beans.
    @Bean
    @Primary
    public ModelResolver modelResolver() {
        if ("composite".equalsIgnoreCase(resolverType)) {
            return compositeModelResolver();
        } else if ("length".equalsIgnoreCase(resolverType)) {
            return lengthBasedModelResolver();
        } else if ("user-level".equalsIgnoreCase(resolverType)) {
            return userLevelModelResolver();
        }
        throw new IllegalArgumentException("Invalid resolver type: " + resolverType);
    }

    @Bean
    public LengthBasedModelResolver lengthBasedModelResolver() {
        return new LengthBasedModelResolver(shortTextModel, longTextModel, maxLength);
    }

    @Bean
    public UserLevelModelResolver userLevelModelResolver() {
        return new UserLevelModelResolver(userLevelModelMap, defaultUserLevelModel);
    }

    @Bean
    public CompositeModelResolver compositeModelResolver() {
        List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver(), userLevelModelResolver());
        return new CompositeModelResolver(resolvers);
    }
}
5. Spring Boot Integration Example
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
public class TextSummarizer {

    private final OpenAiService openAiService;
    private final ModelResolver modelResolver;

    @Autowired
    public TextSummarizer(OpenAiService openAiService, ModelResolver modelResolver) {
        this.openAiService = openAiService;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text, String userLevel) {
        // Several inputs (text content, user tier, ...) can drive model selection.
        // A CompositeModelResolver could first pick by text length and then adjust by user tier;
        // see the request-object sketch after this class for one way to carry both values.
        String modelName = modelResolver.resolve(text);
        return complete(modelName, text);
    }

    public String summarize(String text) {
        String modelName = modelResolver.resolve(text);
        return complete(modelName, text);
    }

    private String complete(String modelName, String text) {
        // gpt-3.5-turbo / gpt-4 are chat models, so call the chat completions endpoint
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelName)
                .messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), text)))
                .maxTokens(100)
                .build();
        return openAiService.createChatCompletion(request).getChoices().get(0).getMessage().getContent();
    }
}
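If both the text and the user tier need to reach the resolvers, one option is to pass a small request object instead of a bare String. This is a hedged sketch that is not part of the code above; the SummaryRequest record and RequestAwareModelResolver class are purely illustrative names:
// Hypothetical carrier object so a single resolve(...) call can see every signal.
record SummaryRequest(String text, String userLevel) {}

// Illustrative resolver that prefers the user tier when it is known.
class RequestAwareModelResolver implements ModelResolver {
    private final UserLevelModelResolver userLevelResolver;
    private final LengthBasedModelResolver lengthResolver;

    RequestAwareModelResolver(UserLevelModelResolver userLevelResolver, LengthBasedModelResolver lengthResolver) {
        this.userLevelResolver = userLevelResolver;
        this.lengthResolver = lengthResolver;
    }

    @Override
    public String resolve(Object input) {
        if (input instanceof SummaryRequest request) {
            // Let the user tier take priority when present, otherwise fall back to text length
            return request.userLevel() != null
                    ? userLevelResolver.resolve(request.userLevel())
                    : lengthResolver.resolve(request.text());
        }
        return lengthResolver.resolve(input);
    }
}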
6. Extensibility
The ModelResolver architecture extends well. To add a new model selection strategy, implement the ModelResolver interface and register it with the CompositeModelResolver.
For example, we can implement an ABTestModelResolver that picks a model according to an A/B test split.
public class ABTestModelResolver implements ModelResolver {
    private final String modelA;
    private final String modelB;
    private final double probabilityA;

    public ABTestModelResolver(String modelA, String modelB, double probabilityA) {
        this.modelA = modelA;
        this.modelB = modelB;
        this.probabilityA = probabilityA;
    }

    @Override
    public String resolve(Object input) {
        double random = Math.random();
        return random < probabilityA ? modelA : modelB;
    }
}
Then register the ABTestModelResolver with the CompositeModelResolver.
@Bean
public ABTestModelResolver abTestModelResolver() {
    return new ABTestModelResolver("gpt-3.5-turbo", "gpt-4", 0.5);
}

@Bean
public CompositeModelResolver compositeModelResolver(LengthBasedModelResolver lengthBasedModelResolver,
                                                     UserLevelModelResolver userLevelModelResolver,
                                                     ABTestModelResolver abTestModelResolver) {
    List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver, userLevelModelResolver, abTestModelResolver);
    return new CompositeModelResolver(resolvers);
}
7. Unit Testing
The ModelResolver architecture also makes unit testing easier: each ModelResolver implementation can be tested in isolation to verify its selection logic.
For example, unit tests for LengthBasedModelResolver:
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

public class LengthBasedModelResolverTest {

    @Test
    public void testResolveShortText() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        String modelName = resolver.resolve("This is a short text.");
        assertEquals("gpt-3.5-turbo", modelName);
    }

    @Test
    public void testResolveLongText() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        // Build an input that is actually longer than the 500-character threshold
        String longText = "This is a long text. ".repeat(30);
        String modelName = resolver.resolve(longText);
        assertEquals("gpt-4", modelName);
    }

    @Test
    public void testInvalidInput() {
        LengthBasedModelResolver resolver = new LengthBasedModelResolver("gpt-3.5-turbo", "gpt-4", 500);
        assertThrows(IllegalArgumentException.class, () -> resolver.resolve(123));
    }
}
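The same approach applies to the other resolvers. As a minimal sketch, a test for UserLevelModelResolver might look like this (the tier names are illustrative):
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class UserLevelModelResolverTest {

    @Test
    public void testResolveKnownLevel() {
        // A known tier maps directly to its configured model
        UserLevelModelResolver resolver =
                new UserLevelModelResolver(Map.of("basic", "gpt-3.5-turbo", "premium", "gpt-4"), "gpt-3.5-turbo");
        assertEquals("gpt-4", resolver.resolve("premium"));
    }

    @Test
    public void testResolveUnknownLevelFallsBackToDefault() {
        // An unknown tier falls back to the configured default model
        UserLevelModelResolver resolver =
                new UserLevelModelResolver(Map.of("premium", "gpt-4"), "gpt-3.5-turbo");
        assertEquals("gpt-3.5-turbo", resolver.resolve("free"));
    }
}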
Code Example: A Complete Runnable Demo
// Dependencies (Gradle):
// implementation("com.theokanning.openai-gpt3-java:service:0.17.0")
// implementation("org.springframework.boot:spring-boot-starter-web:3.0.2")
//
// Maven:
// <dependency>
//     <groupId>com.theokanning.openai-gpt3-java</groupId>
//     <artifactId>service</artifactId>
//     <version>0.17.0</version>
// </dependency>
// <dependency>
//     <groupId>org.springframework.boot</groupId>
//     <artifactId>spring-boot-starter-web</artifactId>
//     <version>3.0.2</version>
// </dependency>
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
interface ModelResolver {
    String resolve(Object input);
}
class LengthBasedModelResolver implements ModelResolver {
    private final String shortTextModel;
    private final String longTextModel;
    private final int maxLength;

    public LengthBasedModelResolver(String shortTextModel, String longTextModel, int maxLength) {
        this.shortTextModel = shortTextModel;
        this.longTextModel = longTextModel;
        this.maxLength = maxLength;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String text = (String) input;
        return text.length() <= maxLength ? shortTextModel : longTextModel;
    }
}
class UserLevelModelResolver implements ModelResolver {
    private final Map<String, String> modelMap;
    private final String defaultModel;

    public UserLevelModelResolver(Map<String, String> modelMap, String defaultModel) {
        this.modelMap = modelMap;
        this.defaultModel = defaultModel;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        String userLevel = (String) input;
        return modelMap.getOrDefault(userLevel, defaultModel);
    }
}
class CompositeModelResolver implements ModelResolver {
    private final List<ModelResolver> resolvers;

    public CompositeModelResolver(List<ModelResolver> resolvers) {
        this.resolvers = resolvers;
    }

    @Override
    public String resolve(Object input) {
        for (ModelResolver resolver : resolvers) {
            String modelName = resolver.resolve(input);
            if (modelName != null && !modelName.isEmpty()) {
                return modelName;
            }
        }
        return null; // Or throw an exception if no model can be resolved
    }
}
@Service
class TextSummarizer {

    private final OpenAiService openAiService;
    private final ModelResolver modelResolver;

    public TextSummarizer(OpenAiService openAiService, ModelResolver modelResolver) {
        this.openAiService = openAiService;
        this.modelResolver = modelResolver;
    }

    public String summarize(String text, String userLevel) {
        // userLevel could also feed into a resolver; here the text alone drives selection
        return complete(modelResolver.resolve(text), text);
    }

    public String summarize(String text) {
        return complete(modelResolver.resolve(text), text);
    }

    private String complete(String modelName, String text) {
        // gpt-3.5-turbo / gpt-4 are chat models, so call the chat completions endpoint
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelName)
                .messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), text)))
                .maxTokens(100)
                .build();
        return openAiService.createChatCompletion(request).getChoices().get(0).getMessage().getContent();
    }
}
@Configuration
class ModelResolverConfig {

    @Value("${model.resolver.type}")
    private String resolverType;

    @Value("${model.resolver.length.short-text-model:gpt-3.5-turbo}")
    private String shortTextModel;

    @Value("${model.resolver.length.long-text-model:gpt-4}")
    private String longTextModel;

    @Value("${model.resolver.length.max-length:500}")
    private int maxLength;

    // Expects a SpEL map literal, e.g. {basic:'gpt-3.5-turbo',premium:'gpt-4'}
    @Value("#{${model.resolver.user-level.model-map:{basic:'gpt-3.5-turbo',premium:'gpt-4'}}}")
    private Map<String, String> userLevelModelMap;

    @Value("${model.resolver.user-level.default-model:gpt-3.5-turbo}")
    private String defaultUserLevelModel;

    // @Primary marks this as the ModelResolver to inject where business code asks for one.
    @Bean
    @Primary
    public ModelResolver modelResolver() {
        if ("composite".equalsIgnoreCase(resolverType)) {
            return compositeModelResolver();
        } else if ("length".equalsIgnoreCase(resolverType)) {
            return lengthBasedModelResolver();
        } else if ("user-level".equalsIgnoreCase(resolverType)) {
            return userLevelModelResolver();
        }
        throw new IllegalArgumentException("Invalid resolver type: " + resolverType);
    }

    @Bean
    public LengthBasedModelResolver lengthBasedModelResolver() {
        return new LengthBasedModelResolver(shortTextModel, longTextModel, maxLength);
    }

    @Bean
    public UserLevelModelResolver userLevelModelResolver() {
        return new UserLevelModelResolver(userLevelModelMap, defaultUserLevelModel);
    }

    @Bean
    public CompositeModelResolver compositeModelResolver() {
        List<ModelResolver> resolvers = Arrays.asList(lengthBasedModelResolver(), userLevelModelResolver());
        return new CompositeModelResolver(resolvers);
    }
}
@SpringBootApplication
@RestController
class DemoApplication {

    private final TextSummarizer textSummarizer;

    public DemoApplication(TextSummarizer textSummarizer) {
        this.textSummarizer = textSummarizer;
    }

    // Declared static so the OpenAiService bean can be created before DemoApplication itself,
    // avoiding a circular dependency through TextSummarizer.
    @Bean
    public static OpenAiService openAiService(@Value("${openai.api.key}") String openaiApiKey) {
        return new OpenAiService(openaiApiKey);
    }

    @GetMapping("/summarize")
    public String summarize(@RequestParam String text) {
        return textSummarizer.summarize(text);
    }

    public static void main(String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
Configure application.properties:
openai.api.key=YOUR_OPENAI_API_KEY
model.resolver.type=composite
model.resolver.length.short-text-model=gpt-3.5-turbo
model.resolver.length.long-text-model=gpt-4
model.resolver.length.max-length=500
model.resolver.user-level.model-map={basic:'gpt-3.5-turbo',premium:'gpt-4'}
model.resolver.user-level.default-model=gpt-3.5-turbo
Architectural Benefits: Decoupled, Flexible, Maintainable
With the ModelResolver architecture, model selection is decoupled from business code, making the codebase more flexible, maintainable, and testable.
- Decoupled: selection logic is no longer scattered across business methods; it is concentrated in the ModelResolver implementations.
- Flexible: different ModelResolver implementations can cover different needs, such as selecting by text length, user tier, or A/B test results.
- Maintainable: changing the selection strategy only means changing a ModelResolver implementation or the configuration, not large amounts of business code.
- Testable: each ModelResolver implementation can be unit tested to verify its selection logic.
Limitations: Added Complexity
The ModelResolver architecture does add some complexity: the ModelResolver interface and its implementations have to be designed and maintained, and the CompositeModelResolver has to be configured. That cost is usually worth paying for the cleaner structure and better maintainability it buys.
Future Direction: Smarter Model Selection
The ModelResolver architecture gives us a unified resolution scheme, but the selection strategies are still written by hand. Going forward, we could explore smarter selection, for example letting a machine-learning model decide which model to use automatically.
For instance, we could train a model that picks the most suitable model based on features of the input text, such as topic, sentiment, and complexity. A rough sketch of how such a resolver could plug into the same interface follows.
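This is purely illustrative: TextComplexityClassifier is a hypothetical component (it could wrap a local ML model or another service), not an existing library, and the threshold-based routing is just one possible policy.
// Hypothetical classifier that scores how demanding an input text is.
interface TextComplexityClassifier {
    double complexityScore(String text); // e.g. 0.0 (trivial) .. 1.0 (very complex)
}

class MlBasedModelResolver implements ModelResolver {
    private final TextComplexityClassifier classifier;
    private final String cheapModel;
    private final String strongModel;
    private final double threshold;

    MlBasedModelResolver(TextComplexityClassifier classifier, String cheapModel, String strongModel, double threshold) {
        this.classifier = classifier;
        this.cheapModel = cheapModel;
        this.strongModel = strongModel;
        this.threshold = threshold;
    }

    @Override
    public String resolve(Object input) {
        if (!(input instanceof String)) {
            throw new IllegalArgumentException("Input must be a string.");
        }
        // Route complex texts to the stronger (more expensive) model.
        return classifier.complexityScore((String) input) >= threshold ? strongModel : cheapModel;
    }
}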
Abstracting and Reusing Model Selection Logic
With the ModelResolver architecture we pull model selection out of the business code, turning it into an abstraction that can be reused. The pattern is not limited to OpenAI models: it fits any scenario where a different service must be chosen based on conditions, as the generic sketch below suggests. It gives us a clear, extensible way to manage multiple models, reduces code complexity, and makes the system more flexible.
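As a hedged illustration of that generalization (the names here are mine, not from the article), the same idea can be expressed as a small generic interface and reused for any conditional service selection:
// A generic variant of ModelResolver: resolve any kind of target (a model name,
// a service endpoint, a queue name, ...) from an input. Illustrative only.
interface Resolver<I, T> {
    T resolve(I input);
}

// Example: picking a storage backend by file size, same pattern as model selection.
class SizeBasedStorageResolver implements Resolver<Long, String> {
    @Override
    public String resolve(Long sizeInBytes) {
        return sizeInBytes <= 5_000_000 ? "local-disk" : "object-storage";
    }
}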