View Javadoc
1   package com.acumenvelocity.ath.steps;
2   
3   import java.util.ArrayList;
4   import java.util.List;
5   
6   import com.acumenvelocity.ath.common.AthUtil;
7   import com.acumenvelocity.ath.common.Const;
8   import com.acumenvelocity.ath.common.ControllerUtil;
9   import com.acumenvelocity.ath.common.Log;
10  import com.acumenvelocity.ath.common.OkapiUtil;
11  import com.acumenvelocity.ath.common.OriginalTuAnnotation;
12  import com.acumenvelocity.ath.model.MtResources;
13  import com.acumenvelocity.ath.mt.confidence.ConfidenceScoredTranslation;
14  import com.acumenvelocity.ath.mt.confidence.HybridDataStructures.HybridEvaluationResult;
15  import com.acumenvelocity.ath.mt.confidence.HybridDataStructures.HybridScoredSegment;
16  import com.acumenvelocity.ath.mt.confidence.HybridDataStructures.HybridTranslationScore;
17  import com.acumenvelocity.ath.mt.confidence.HybridQualityEstimator;
18  
19  import net.sf.okapi.common.Event;
20  import net.sf.okapi.common.IResource;
21  import net.sf.okapi.common.Util;
22  import net.sf.okapi.common.annotation.AltTranslationsAnnotation;
23  import net.sf.okapi.common.query.MatchType;
24  import net.sf.okapi.common.resource.ITextUnit;
25  import net.sf.okapi.common.resource.Segment;
26  import net.sf.okapi.common.resource.TextContainer;
27  import net.sf.okapi.common.resource.TextFragment;
28  import net.sf.okapi.common.resource.TextFragmentUtil;
29  import net.sf.okapi.lib.translation.QueryUtil;
30  
31  /**
32   * MT Confidence Scoring Step for Okapi Framework v1.47.0
33   * 
34   * Translates source segments using multiple MT models (Google Cloud Translate v3 NMT,
35   * Translation LLM, and custom AutoML models), evaluates translation quality using
36   * Vertex AI (MetricX) and heuristic methods, then attaches confidence-scored
37   * alternate translations to each segment.
38   */
39  public class MtConfidenceScoringStep extends BaseTuBatchProcessingStep {
40  
41    private final List<MtResources> mtCustomResources;
42    private final List<MtResources> modelConfigs = new ArrayList<>();
43    private final List<String> sourceTexts = new ArrayList<>();
44    private final List<TextFragment> sourceTfs = new ArrayList<>();
45    private final List<SegmentInfo> segmentInfos = new ArrayList<>();
46    private final boolean mtSendPlainText;
47    private final QueryUtil qutil = new QueryUtil();
48  
49    private HybridEvaluationResult results;
50  
51    /**
52     * Creates a new MT Confidence Scoring Step.
53     * 
54     * @param mtCustomResources List of user's custom models and glossaries
55     */
56    public MtConfidenceScoringStep(List<MtResources> mtCustomResources, boolean mtSendPlainText) {
57      super();
58      this.mtCustomResources = mtCustomResources != null ? mtCustomResources : new ArrayList<>();
59      this.mtSendPlainText = mtSendPlainText;
60    }
61  
62    public String getName() {
63      return "MT Confidence Scoring";
64    }
65  
66    @Override
67    public String getDescription() {
68      return "Machine translates source segments using Google Cloud Translate API v3 "
69          + "(NMT, Translation LLM) and custom AutoML models/glossaries, then uses "
70          + "Google Vertex AI evaluation service and heuristics to calculate confidence "
71          + "scores for those translations.";
72    }
73  
74    /**
75     * Initialize model configurations with built-in Google models + custom models
76     */
77    private void initializeModelConfigs() {
78      modelConfigs.clear();
79  
80      // Add Google Cloud Translate v3 built-in models
81  
82      // 1. NMT (Neural Machine Translation) - standard model
83      MtResources nmt = new MtResources();
84  
85      nmt.setMtModelId("general/nmt");
86      nmt.setMtModelProjectId(ControllerUtil.getProjectId());
87      nmt.setMtModelProjectLocation(Const.US_CENTRAL1_PROJECT_LOCATION);
88  
89      modelConfigs.add(nmt);
90  
91      // 2. Translation LLM - advanced model
92      MtResources translationLlm = new MtResources();
93  
94      translationLlm.setMtModelId("general/translation-llm");
95      translationLlm.setMtModelProjectId(ControllerUtil.getProjectId());
96      translationLlm.setMtModelProjectLocation(Const.US_CENTRAL1_PROJECT_LOCATION);
97  
98      modelConfigs.add(translationLlm);
99  
100     // 3. Add user's custom models and glossaries
101     if (!Util.isEmpty(mtCustomResources)) {
102       modelConfigs.addAll(mtCustomResources);
103     }
104 
105     Log.info(MtConfidenceScoringStep.class,
106         "Initialized {} model configurations (2 built-in + {} custom)",
107         modelConfigs.size(), mtCustomResources != null ? mtCustomResources.size() : 0);
108   }
109 
110   /**
111    * Pre-process text units: collect source segments
112    */
113   private void preProcessTextUnit(ITextUnit tu) {
114     TextContainer source = tu.getSource();
115 
116     if (source == null) {
117       Log.error(getClass(), "Source of TU '{}' is null", tu.getId());
118       return;
119     }
120     
121     for (Segment segment : source.getSegments()) {
122       TextFragment srcTf = segment.getContent();
123 
124       // Skip empty segments
125       if (srcTf == null || srcTf.isEmpty()) {
126         Log.trace(getClass(), "Skipping empty segment in TU '{}'", tu.getId());
127         continue;
128       }
129       
130       sourceTfs.add(srcTf.clone());
131     }
132     
133     if (mtSendPlainText) {
134       // Store the original TU with codes, remove codes to improve MT quality
135       // CodesReinsertionStep will use OriginalTuAnnotation to get the original source codes
136       tu.setAnnotation(new OriginalTuAnnotation(tu.clone(), getSourceLocale()));
137       OkapiUtil.removeCodes(tu, true);
138     }
139 
140     // Process each segment in the text unit
141     for (Segment segment : source.getSegments()) {
142       TextFragment srcTf = segment.getContent();
143 
144       // Skip empty segments
145       if (srcTf == null || srcTf.isEmpty()) {
146         Log.trace(getClass(), "Skipping empty segment in TU '{}'", tu.getId());
147         continue;
148       }
149 
150       String sourceText = null;
151 
152       if (mtSendPlainText) {
153         sourceText = srcTf.getText();
154 
155       } else {
156         sourceText = qutil.toCodedHTML(srcTf);
157       }
158 
159       sourceTexts.add(sourceText);
160 
161       // Store segment info for later mapping
162       segmentInfos.add(new SegmentInfo(tu, segment.getId()));
163 
164       Log.trace(getClass(), "Collected segment [{}]: '{}'",
165           sourceTexts.size() - 1, sourceText);
166     }
167   }
168 
169   /**
170    * Post-process text units: apply translations with confidence scores
171    */
172   private void postProcessTextUnits(HybridEvaluationResult results) {
173     if (results == null || results.getSegments().isEmpty()) {
174       Log.warn(getClass(), "No evaluation results available");
175       return;
176     }
177 
178     List<HybridScoredSegment> scoredSegments = results.getSegments();
179 
180     if (scoredSegments.size() != segmentInfos.size()) {
181       Log.error(getClass(),
182           "Mismatch: {} scored segments but {} segment infos",
183           scoredSegments.size(), segmentInfos.size());
184 
185       return;
186     }
187 
188     // Process each scored segment
189     for (int i = 0; i < scoredSegments.size(); i++) {
190       HybridScoredSegment scoredSegment = scoredSegments.get(i);
191       SegmentInfo segInfo = segmentInfos.get(i);
192 
193       ITextUnit tu = segInfo.textUnit;
194       String segmentId = segInfo.segmentId;
195 
196       // Get or create target container
197       TextContainer target = tu.getTarget(getTargetLocale());
198 
199       if (target == null) {
200         target = tu.createTarget(getTargetLocale(), false, IResource.COPY_SEGMENTATION);
201         Log.trace(getClass(), "Created target container for TU '{}'", tu.getId());
202       }
203 
204       // Get or create target segment
205       Segment targetSegment = target.getSegments().get(segmentId);
206 
207       if (targetSegment == null) {
208         targetSegment = new Segment(segmentId);
209         target.append(targetSegment);
210         Log.trace(getClass(), "Created target segment '{}' in TU '{}'", segmentId, tu.getId());
211       }
212 
213       // Get all scored translations for this segment
214       List<HybridTranslationScore> scores = scoredSegment.getScores();
215 
216       if (scores.isEmpty()) {
217         Log.warn(getClass(), "No translations available for segment {} in TU '{}'",
218             i, tu.getId());
219 
220         continue;
221       }
222 
223       // Create or get AltTranslationsAnnotation for this segment
224       AltTranslationsAnnotation ata = targetSegment.getAnnotation(
225           AltTranslationsAnnotation.class);
226 
227       if (ata == null) {
228         ata = new AltTranslationsAnnotation();
229         targetSegment.setAnnotation(ata);
230         Log.trace(getClass(), "Created AltTranslationsAnnotation for segment '{}'", segmentId);
231       }
232 
233       // Add all translations as alternate translations with confidence scores
234       // TextFragment sourceTf = new TextFragment(scoredSegment.getSourceText());
235       TextFragment sourceTf = scoredSegment.getSourceTf();
236 
237       for (HybridTranslationScore score : scores) {
238         String targetText = score.getTranslation();
239         TextFragment targetTf = null;
240 
241         if (mtSendPlainText) {
242           targetTf = new TextFragment(targetText);
243 
244         } else {
245           targetTf = qutil.fromCodedHTMLToFragment(targetText, null);
246         }
247 
248         OkapiUtil.removeExtraCodes(sourceTf.getCodes(), targetTf);
249 
250         // Align codes and copy metadata from source to target
251         TextFragmentUtil.alignAndCopyCodeMetadata(sourceTf, targetTf, true, true);
252 
253         // Rearrange opening and closing codes
254         OkapiUtil.rearrangeCodes(sourceTf.getCodes(), targetTf);
255 
256         // Create ConfidenceScoredTranslation with all available scores
257         ConfidenceScoredTranslation cst = new ConfidenceScoredTranslation(
258             getSourceLocale(),
259             getTargetLocale(),
260             sourceTf,
261             sourceTf, // alternate source same as original for MT
262             targetTf,
263             MatchType.MT, // All are machine translations
264             AthUtil.extractLastSection(score.getModelId()),
265             score.getConfidence());
266 
267         // Add metadata
268         if (score.isAnomalyFlagged()) {
269           cst.setEngine("ANOMALY: " + score.getAnomalyReason());
270 
271         } else {
272           cst.setEngine(score.getMethod());
273         }
274 
275         ata.add(cst);
276 
277         Log.trace(getClass(),
278             "Added alt-trans [{}] confidence={} MetricX={} Heuristic={}: '{}'",
279             score.getModelId(),
280             score.getConfidence(),
281             score.getMetricXScore() != null ? String.format("%.2f", score.getMetricXScore())
282                 : "N/A",
283 
284             score.getHeuristicScore() != null ? String.format("%.3f", score.getHeuristicScore())
285                 : "N/A",
286 
287             score.getTranslation());
288       }
289 
290       // Sort by confidence (already sorted, but ensure it)
291       ata.sort();
292 
293       // Best translation (highest confidence)
294       ConfidenceScoredTranslation bestCst = (ConfidenceScoredTranslation) ata.getFirst();
295       TextFragment bestTranslation = bestCst.getTarget().getFirstContent();
296       targetSegment.setContent(bestTranslation);
297 
298       Log.debug(getClass(),
299           "Set best translation for TU '{}' segment '{}': '{}' (confidence: {})",
300           tu.getId(), segmentId, bestTranslation, bestCst.getConfidenceScore());
301     }
302   }
303 
304   /**
305    * Convert 0-1 double confidence to 0-100 integer percentage
306    */
307   public static int convertToPercentage(double confidence) {
308     return (int) Math.round(confidence * 100.0);
309   }
310 
311   @Override
312   protected void clear() {
313     sourceTexts.clear();
314     segmentInfos.clear();
315     results = null;
316   }
317 
318   @Override
319   protected void processTuEvents(List<Event> tuEvents) {
320     initializeModelConfigs();
321 
322     // Step 1: Pre-process all text units
323     Log.info(getClass(), "Pre-processing {} text units...", tuEvents.size());
324 
325     for (Event tue : tuEvents) {
326       ITextUnit tu = tue.getTextUnit();
327       preProcessTextUnit(tu);
328     }
329 
330     Log.info(getClass(), "Collected {} source segments from {} text units",
331         sourceTexts.size(), tuEvents.size());
332 
333     // Step 2: Evaluate translations with all models
334     if (!sourceTexts.isEmpty()) {
335       try (HybridQualityEstimator hqe = new HybridQualityEstimator(
336           ControllerUtil.getProjectId(),
337           Const.ATH_GCP_PROJECT_LOCATION,
338           mtSendPlainText)) {
339 
340         Log.info(getClass(),
341             "Evaluating translations: {} segments, {} models, {}→{}",
342             sourceTexts.size(),
343             modelConfigs.size(),
344             getSourceLocale(),
345             getTargetLocale());
346 
347         results = hqe.evaluateTranslations(
348             sourceTexts,
349             sourceTfs,
350             getSourceLocale().toString(),
351             getTargetLocale().toString(),
352             modelConfigs);
353 
354         Log.info(getClass(),
355             "Evaluation complete using strategy: {}",
356             results.getStrategy());
357 
358       } catch (Exception e) {
359         Log.error(getClass(), "Error during translation evaluation: {}", e.getMessage(), e);
360         // Continue with empty results - segments will remain untranslated
361       }
362 
363     } else {
364       Log.warn(getClass(), "No source segments to translate");
365     }
366 
367     // Step 3: Post-process text units with results
368     if (results != null) {
369       Log.info(getClass(), "Post-processing text units with evaluation results...");
370       postProcessTextUnits(results);
371     }
372 
373     Log.info(getClass(),
374         "MT Confidence Scoring complete: processed {} TUs, {} segments",
375         getNumProcessedTus(), sourceTexts.size());
376   }
377 }