Bug description
I got an error when my program called TransformersEmbeddingModel's embed() method concurrently to compute many embeddings, even though I had called afterPropertiesSet() before the concurrent calls. Here is the stack trace:
java.lang.IllegalStateException: The engine PyTorch was not able to initialize
    at ai.djl.engine.Engine.getEngine(Engine.java:218)
    at ai.djl.engine.Engine.getInstance(Engine.java:149)
    at ai.djl.ndarray.NDManager.newBaseManager(NDManager.java:120)
    at org.springframework.ai.transformers.TransformersEmbeddingModel.call(TransformersEmbeddingModel.java:280)
    at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:232)
    at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:212)
    at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:217)
    at org.springframework.ai.transformers.TransformersEmbeddingModelTests.lambda$parallelEmbedDocument$0(TransformersEmbeddingModelTests.java:71)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:840)
...
After some troubleshooting and investigation, I concluded that DJL 0.26.0, the version Spring AI uses, has a known issue that matches what I encountered. This issue is fixed in DJL 0.27.0. See: https://github.com/deepjavalibrary/djl/pull/3005.
I verified with DJL 0.28.0, and the errors disappeared. I therefore suggest upgrading DJL to 0.28.0.
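The symptom (errors only when several threads reach the engine for the first time at once) looks like a race during one-time initialization. I have not checked exactly how the DJL pull request fixes it, so the following is only a generic Java illustration of thread-safe one-time initialization, not DJL code. FakeEngine, EngineRegistry and LazyEngineDemo are hypothetical names. The sketch uses the initialization-on-demand holder idiom: the JVM locks class initialization, so the engine is created exactly once even when ten threads request it simultaneously.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class LazyEngineDemo {

    // Hypothetical stand-in for an expensive native engine (NOT a DJL class).
    public static final class FakeEngine {
        public static final AtomicInteger CREATED = new AtomicInteger();

        FakeEngine() {
            CREATED.incrementAndGet(); // count how many instances were ever built
        }
    }

    // Hypothetical registry using the initialization-on-demand holder idiom.
    public static final class EngineRegistry {
        private static final class Holder {
            // Runs once, under the JVM's class-initialization lock.
            static final FakeEngine INSTANCE = new FakeEngine();
        }

        public static FakeEngine get() {
            return Holder.INSTANCE;
        }
    }

    // Calls EngineRegistry.get() from `threads` workers; returns the distinct instances seen.
    public static Set<FakeEngine> callConcurrently(int threads) {
        Set<FakeEngine> seen = ConcurrentHashMap.newKeySet(); // identity-based: FakeEngine has no equals()
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        for (int i = 0; i < threads; i++) {
            pool.execute(() -> seen.add(EngineRegistry.get()));
        }
        pool.shutdown();
        try {
            if (!pool.awaitTermination(30, TimeUnit.SECONDS)) {
                throw new IllegalStateException("workers did not finish in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        }
        return seen;
    }

    public static void main(String[] args) {
        Set<FakeEngine> seen = callConcurrently(10);
        // prints "distinct engines: 1, created: 1"
        System.out.println("distinct engines: " + seen.size() + ", created: " + FakeEngine.CREATED.get());
    }
}
```

A naive `if (instance == null) instance = new FakeEngine();` in place of the holder would give no such guarantee: two threads could both see null and initialize twice, which is the kind of race the symptom above suggests.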
Environment
Spring AI version: 0.8.1, v1.0.0-M1
Java: 17
Vector store: none
Steps to reproduce
You can reproduce the issue with the following test, which I adapted from TransformersEmbeddingModelTests:
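import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;

// Method added to TransformersEmbeddingModelTests; DF is that class's existing number formatter field.
@Test
void parallelEmbedDocument() throws Exception {
    TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();
    // Initialize the model once, before any concurrent access.
    embeddingModel.afterPropertiesSet();

    // Errors thrown inside worker threads do not fail the test by themselves, so collect them.
    Queue<Throwable> failures = new ConcurrentLinkedQueue<>();
    ExecutorService executorService = Executors.newFixedThreadPool(10);
    for (int i = 0; i < 10; i++) {
        executorService.execute(() -> {
            try {
                List<Double> embed = embeddingModel.embed(new Document("Hello world"));
                assertThat(embed).hasSize(384);
                assertThat(DF.format(embed.get(0))).isEqualTo(DF.format(-0.19744634628295898));
                assertThat(DF.format(embed.get(383))).isEqualTo(DF.format(0.17298996448516846));
            }
            catch (Throwable t) {
                t.printStackTrace();
                failures.add(t);
            }
        });
    }
    executorService.shutdown();
    assertThat(executorService.awaitTermination(30, TimeUnit.SECONDS)).isTrue();
    // Fails if any thread hit "The engine PyTorch was not able to initialize" (or any other error).
    assertThat(failures).isEmpty();
}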
Expected behavior
No "The engine PyTorch was not able to initialize" errors when embed() is called from multiple threads.
Minimal Complete Reproducible example
See the test under "Steps to reproduce" above.
Comment From: ThomasVitale
@nichozhan Did https://github.com/spring-projects/spring-ai/pull/837 fix this issue?
Comment From: nichozhan
Hi @ThomasVitale, yes, I think so.