1 package com.acumenvelocity.ath.gct.v3;
2
3 import java.io.FileInputStream;
4 import java.util.ArrayList;
5 import java.util.List;
6 import java.util.concurrent.CompletableFuture;
7 import java.util.stream.Collectors;
8
9 import com.acumenvelocity.ath.common.AthUtil;
10 import com.acumenvelocity.ath.common.Const;
11 import com.acumenvelocity.ath.common.Log;
12 import com.acumenvelocity.ath.steps.BatchMtStep;
13 import com.google.api.gax.core.FixedCredentialsProvider;
14 import com.google.auth.oauth2.GoogleCredentials;
15 import com.google.cloud.aiplatform.v1.LocationName;
16 import com.google.cloud.translate.v3.GlossaryName;
17 import com.google.cloud.translate.v3.TranslateTextGlossaryConfig;
18 import com.google.cloud.translate.v3.TranslateTextRequest;
19 import com.google.cloud.translate.v3.TranslateTextResponse;
20 import com.google.cloud.translate.v3.TranslationServiceClient;
21 import com.google.cloud.translate.v3.TranslationServiceSettings;
22
23 import net.sf.okapi.common.Util;
24
25 public class AthTranslation {
26
27 private static TranslationServiceClient client;
28
29 public static void init() throws Exception {
30 if (client == null || client.isShutdown()) {
31 TranslationServiceSettings.Builder settingsBuilder = TranslationServiceSettings.newBuilder();
32
33
34 if (!Util.isEmpty(Const.ATH_GCP_SECRET_FILE)) {
35 try (FileInputStream credentialsStream = new FileInputStream(Const.ATH_GCP_SECRET_FILE)) {
36 GoogleCredentials credentials = GoogleCredentials.fromStream(credentialsStream);
37 settingsBuilder.setCredentialsProvider(FixedCredentialsProvider.create(credentials));
38 }
39
40 } else if (!Util.isEmpty(Const.ATH_GCT_API_KEY)) {
41
42
43
44
45
46
47
48 Log.warn(AthTranslation.class,
49 "API Key provided, but Service Account Credentials (JSON file) are preferred for "
50 + "Google Cloud V3 Client.");
51 }
52
53 client = TranslationServiceClient.create(settingsBuilder.build());
54 }
55 }
56
57
58
59
60 public static List<String> translateBatch(
61 List<String> texts,
62 String sourceLang,
63 String targetLang,
64 String mimeType,
65 String projectId,
66 String projectLocation,
67 String mtModelProjectId,
68 String mtModelProjectLocation,
69 String mtModelId,
70 String mtGlossaryProjectId,
71 String mtGlossaryProjectLocation,
72 String mtGlossaryId) {
73
74 String fullModelName = null;
75
76 try {
77
78
79
80 final int MAX_TEXTS_PER_BATCH = 1000;
81 final int MAX_CODEPOINTS_PER_BATCH = 25000;
82
83 LocationName parent = LocationName.of(
84
85
86 projectId,
87 projectLocation);
88
89
90 if (!Util.isEmpty(mtModelId)) {
91 String modelProjectId = AthUtil.fallback(mtModelProjectId, projectId);
92
93 String modelProjectLocation = AthUtil.fallback(mtModelProjectLocation, projectLocation);
94
95
96
97
98
99
100 fullModelName = Log.format("projects/{}/locations/{}/models/{}", modelProjectId,
101 modelProjectLocation, mtModelId);
102 }
103
104
105 TranslateTextGlossaryConfig glossaryConfig = null;
106
107 if (!Util.isEmpty(mtGlossaryId)) {
108 String glossaryProjectId = AthUtil.fallback(mtGlossaryProjectId, projectId);
109
110 String glossaryProjectLocation = AthUtil.fallback(
111 mtGlossaryProjectLocation, projectLocation);
112
113 GlossaryName glossaryName = GlossaryName.of(
114 glossaryProjectId,
115 glossaryProjectLocation,
116 mtGlossaryId);
117
118 glossaryConfig = TranslateTextGlossaryConfig.newBuilder()
119 .setGlossary(glossaryName.toString())
120 .build();
121 }
122
123
124 List<List<String>> batches = createBatches(texts, MAX_TEXTS_PER_BATCH,
125 MAX_CODEPOINTS_PER_BATCH);
126
127
128 final String modelNameForLambda = fullModelName;
129 final TranslateTextGlossaryConfig glossaryConfigForLambda = glossaryConfig;
130
131 List<CompletableFuture<List<String>>> futures = batches.stream()
132 .map(batch -> CompletableFuture
133 .supplyAsync(
134 () -> translateSingleBatch(client, batch, sourceLang, targetLang, mimeType,
135 parent, modelNameForLambda, glossaryConfigForLambda)))
136 .collect(Collectors.toList());
137
138
139 List<String> allTranslations = new ArrayList<>();
140
141 for (CompletableFuture<List<String>> future : futures) {
142 List<String> batchResult = future.get();
143
144 if (batchResult == null) {
145 Log.error(BatchMtStep.class, "One of the translation batches failed");
146 return null;
147 }
148
149 allTranslations.addAll(batchResult);
150 }
151
152 return allTranslations;
153
154 } catch (Exception e) {
155 Log.error(BatchMtStep.class, e, "Translation failed for model '{}'", fullModelName);
156 }
157
158 return null;
159 }
160
161
162
163
164 private static List<List<String>> createBatches(List<String> texts, int maxTextsPerBatch,
165 int maxCodepointsPerBatch) {
166
167 List<List<String>> batches = new ArrayList<>();
168 List<String> currentBatch = new ArrayList<>();
169 int currentCodepoints = 0;
170
171 for (String text : texts) {
172 int textCodepoints = text.codePointCount(0, text.length());
173
174
175 if (!currentBatch.isEmpty() &&
176 (currentBatch.size() >= maxTextsPerBatch ||
177 currentCodepoints + textCodepoints > maxCodepointsPerBatch)) {
178 batches.add(currentBatch);
179 currentBatch = new ArrayList<>();
180 currentCodepoints = 0;
181 }
182
183 currentBatch.add(text);
184 currentCodepoints += textCodepoints;
185 }
186
187
188 if (!currentBatch.isEmpty()) {
189 batches.add(currentBatch);
190 }
191
192 return batches;
193 }
194
195
196
197
198 private static List<String> translateSingleBatch(
199 TranslationServiceClient client,
200 List<String> batch,
201 String sourceLang,
202 String targetLang,
203 String mimeType,
204 LocationName parent,
205 String fullModelName,
206 TranslateTextGlossaryConfig glossaryConfig) {
207
208 try {
209 TranslateTextRequest.Builder requestBuilder = TranslateTextRequest.newBuilder()
210 .setParent(parent.toString())
211 .setSourceLanguageCode(sourceLang)
212 .setTargetLanguageCode(targetLang)
213 .addAllContents(batch)
214 .setMimeType(mimeType);
215
216 if (fullModelName != null) {
217 requestBuilder.setModel(fullModelName);
218 }
219
220 if (glossaryConfig != null) {
221 requestBuilder.setGlossaryConfig(glossaryConfig);
222 }
223
224 TranslateTextResponse response = client.translateText(requestBuilder.build());
225
226 return response.getTranslationsList().stream()
227 .map(com.google.cloud.translate.v3.Translation::getTranslatedText)
228 .collect(Collectors.toList());
229
230 } catch (Exception e) {
231 Log.error(BatchMtStep.class, e, "Translation failed for batch");
232 return null;
233 }
234 }
235
236
237
238
239
240
241
242
243
244
245
246
247 public static List<String> translateBatch(List<String> translations, String sourceLang,
248 String targetLang, String mimeType, String projectId, String projectLocation) {
249
250 return translateBatch(translations, sourceLang, targetLang, mimeType, projectId,
251 projectLocation, null, null, null, null, null, null);
252 }
253
254 public static void done() throws Exception {
255 if (client != null) {
256 client.close();
257 }
258 }
259
260 public static TranslationServiceClient getClient() {
261 return client;
262 }
263 }