View Javadoc
1   package net.sf.okapi.steps.llmsentencealigner;
2   
3   import java.util.ArrayList;
4   import java.util.Iterator;
5   import java.util.List;
6   
7   import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignment;
8   import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignmentInput;
9   import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignmentOutput;
10  import com.acumenvelocity.ath.common.AlignmentData.ParagraphAlignment;
11  import com.acumenvelocity.ath.common.AlignmentData.ParagraphWithSegments;
12  import com.acumenvelocity.ath.common.AlignmentData.SegmentInfo;
13  import com.acumenvelocity.ath.common.AlignmentData.SentenceAlignment;
14  import com.acumenvelocity.ath.common.ConversionUtil;
15  import com.acumenvelocity.ath.common.Log;
16  import com.acumenvelocity.ath.common.OkapiUtil;
17  import com.acumenvelocity.ath.gemini.GenAi;
18  import com.acumenvelocity.ath.model.InlineCode;
19  import com.acumenvelocity.ath.model.x.LayeredTextX;
20  import com.acumenvelocity.ath.steps.BaseAlignerStep;
21  
22  import net.sf.okapi.common.IParameters;
23  import net.sf.okapi.common.IResource;
24  import net.sf.okapi.common.UsingParameters;
25  import net.sf.okapi.common.Util;
26  import net.sf.okapi.common.exceptions.OkapiException;
27  import net.sf.okapi.common.filters.IFilter;
28  import net.sf.okapi.common.resource.AlignmentStatus;
29  import net.sf.okapi.common.resource.ITextUnit;
30  import net.sf.okapi.common.resource.Segment;
31  import net.sf.okapi.common.resource.TextContainer;
32  import net.sf.okapi.common.resource.TextFragment;
33  import net.sf.okapi.common.resource.TextFragmentUtil;
34  import net.sf.okapi.common.resource.TextUnitUtil;
35  
36  @UsingParameters(LlmSentenceAlignerParameters.class)
37  public class LlmSentenceAlignerStep extends BaseAlignerStep {
38  
39    private LlmSentenceAlignerParameters params;
40  
41    private final List<List<SegmentDataWithCodes>> sourceSegmentData = new ArrayList<>();
42    private final List<List<SegmentDataWithCodes>> targetSegmentData = new ArrayList<>();
43  
44    public LlmSentenceAlignerStep(IFilter targetFilter) {
45      super(targetFilter);
46      params = new LlmSentenceAlignerParameters();
47    }
48  
49    private static class SegmentDataWithCodes {
50      String text;
51      List<InlineCode> codes = new ArrayList<>();
52    }
53  
54    @Override
55    public String getName() {
56      return "LLM Sentence Alignment";
57    }
58  
59    @Override
60    public String getDescription() {
61      return "Aligns paragraphs and sentences using LLM. Handles crossed paragraphs and different document structures.";
62    }
63  
64    @Override
65    public LlmSentenceAlignerParameters getParameters() {
66      return params;
67    }
68  
69    @Override
70    public void setParameters(IParameters params) {
71      this.params = (LlmSentenceAlignerParameters) params;
72    }
73  
74    @Override
75    protected boolean isSegmentSource() {
76      return params.isSegmentSource();
77    }
78  
79    @Override
80    protected boolean isSegmentTarget() {
81      return params.isSegmentTarget();
82    }
83  
84    @Override
85    protected boolean isUseCustomSourceRules() {
86      return params.isUseCustomSourceRules();
87    }
88  
89    @Override
90    protected boolean isUseCustomTargetRules() {
91      return params.isUseCustomTargetRules();
92    }
93  
94    @Override
95    protected String getCustomSourceRulesPath() {
96      return params.getCustomSourceRulesPath();
97    }
98  
99    @Override
100   protected String getCustomTargetRulesPath() {
101     return params.getCustomTargetRulesPath();
102   }
103 
104   @Override
105   protected boolean isCollapseWhitespace() {
106     return params.isCollapseWhitespace();
107   }
108 
109   @Override
110   protected void clear() {
111     super.clear();
112     sourceSegmentData.clear();
113     targetSegmentData.clear();
114   }
115 
116   @Override
117   protected void performAlignment(List<ITextUnit> sourceTUs, List<ITextUnit> targetTUs) {
118     // Single LLM call for both paragraph and sentence alignment
119     CombinedAlignmentOutput alignmentOutput = performCombinedAlignment(sourceTUs, targetTUs);
120 
121     // Apply alignments
122     applyCombinedAlignments(alignmentOutput, sourceTUs, targetTUs);
123 
124     // Set alignment origin metadata
125     for (ITextUnit tu : sourceTUs) {
126       OkapiUtil.setAlOrigin(tu, getSourceLocale(), getTargetLocale());
127     }
128   }
129 
130   private CombinedAlignmentOutput performCombinedAlignment(List<ITextUnit> sourceTUs,
131       List<ITextUnit> targetTUs) {
132     CombinedAlignmentInput input = new CombinedAlignmentInput();
133     input.sourceLanguage = getSourceLocale().toString();
134     input.targetLanguage = getTargetLocale().toString();
135     input.task = "Align paragraphs first, then align sentences within each paragraph pair";
136 
137     // Build source paragraphs with segments (codes removed)
138     for (ITextUnit srcTu : sourceTUs) {
139       ParagraphWithSegments pws = new ParagraphWithSegments();
140       pws.position = input.sourceParagraphs.size();
141       pws.id = srcTu.getId();
142       pws.context = srcTu.getName();
143 
144       List<SegmentDataWithCodes> segDataList = new ArrayList<>();
145 
146       for (Segment seg : srcTu.getSource().getSegments()) {
147         LayeredTextX slt = ConversionUtil.toLayeredText(seg.text, getSourceLocale());
148 
149         SegmentInfo si = new SegmentInfo();
150         si.position = pws.segments.size();
151         si.text = slt.getText();
152         pws.segments.add(si);
153 
154         SegmentDataWithCodes sdc = new SegmentDataWithCodes();
155         sdc.text = slt.getText();
156         sdc.codes = slt.getCodes();
157         segDataList.add(sdc);
158       }
159 
160       input.sourceParagraphs.add(pws);
161       sourceSegmentData.add(segDataList);
162     }
163 
164     // Build target paragraphs with segments (codes removed)
165     for (ITextUnit trgTu : targetTUs) {
166       ParagraphWithSegments pws = new ParagraphWithSegments();
167       pws.position = input.targetParagraphs.size();
168       pws.id = trgTu.getId();
169       pws.context = trgTu.getName();
170 
171       List<SegmentDataWithCodes> segDataList = new ArrayList<>();
172 
173       for (Segment seg : trgTu.getSource().getSegments()) {
174         LayeredTextX tlt = ConversionUtil.toLayeredText(seg.text, getTargetLocale());
175 
176         SegmentInfo si = new SegmentInfo();
177         si.position = pws.segments.size();
178         si.text = tlt.getText();
179         pws.segments.add(si);
180 
181         SegmentDataWithCodes sdc = new SegmentDataWithCodes();
182         sdc.text = tlt.getText();
183         sdc.codes = tlt.getCodes();
184         segDataList.add(sdc);
185       }
186 
187       input.targetParagraphs.add(pws);
188       targetSegmentData.add(segDataList);
189     }
190 
191     // LLM-based alignment
192     try {
193       return GenAi.alignParagraphsAndSentences(params.getAlignmentModelName(), input);
194 
195     } catch (Exception e) {
196       Log.error(getClass(), "LLM alignment failed: {}", e.getMessage(), e);
197       throw new OkapiException("LLM alignment failed", e);
198     }
199   }
200 
201   private void applyCombinedAlignments(CombinedAlignmentOutput output, List<ITextUnit> sourceTUs,
202       List<ITextUnit> targetTUs) {
203 
204     for (CombinedAlignment combined : output.alignments) {
205       ParagraphAlignment paraAlign = combined.paragraphAlignment;
206 
207       if (params.isLogAlignmentDetails()) {
208         Log.debug(getClass(), "Paragraph alignment: type={}, src={}, trg={}",
209             paraAlign.type, paraAlign.sourceParagraphPositions, paraAlign.targetParagraphPositions);
210       }
211 
212       if (Util.isEmpty(paraAlign.sourceParagraphPositions)
213           || Util.isEmpty(paraAlign.targetParagraphPositions)) {
214 
215         Log.warn(getClass(), "Problematic para alignment -- source: {}, target: {}",
216             paraAlign.sourceParagraphPositions, paraAlign.targetParagraphPositions);
217 
218         continue;
219       }
220 
221       // Collect source and target TUs for this paragraph pair
222       List<ITextUnit> srcTuGroup = new ArrayList<>();
223       List<List<SegmentDataWithCodes>> srcSegDataGroup = new ArrayList<>();
224 
225       for (int srcParaIndex : paraAlign.sourceParagraphPositions) {
226         srcTuGroup.add(sourceTUs.get(srcParaIndex));
227         srcSegDataGroup.add(sourceSegmentData.get(srcParaIndex));
228       }
229 
230       List<ITextUnit> trgTuGroup = new ArrayList<>();
231       List<List<SegmentDataWithCodes>> trgSegDataGroup = new ArrayList<>();
232 
233       for (int trgParaIndex : paraAlign.targetParagraphPositions) {
234         trgTuGroup.add(targetTUs.get(trgParaIndex));
235         trgSegDataGroup.add(targetSegmentData.get(trgParaIndex));
236       }
237 
238       applySentenceAlignmentsToParagraphPair(
239           srcTuGroup, trgTuGroup, srcSegDataGroup, trgSegDataGroup,
240           combined.sentenceAlignments);
241     }
242   }
243 
244   private void applySentenceAlignmentsToParagraphPair(
245       List<ITextUnit> srcTuGroup, List<ITextUnit> trgTuGroup,
246       List<List<SegmentDataWithCodes>> srcSegDataGroup,
247       List<List<SegmentDataWithCodes>> trgSegDataGroup,
248       List<SentenceAlignment> sentenceAlignments) {
249 
250     if (Util.isEmpty(srcTuGroup) || Util.isEmpty(trgTuGroup)) {
251       Log.warn(getClass(), "Problematic para alignment -- source: {}, target: {}",
252           srcTuGroup, trgTuGroup);
253 
254       return;
255     }
256 
257     if (Util.isEmpty(srcSegDataGroup) || Util.isEmpty(trgSegDataGroup)) {
258       Log.warn(getClass(), "Problematic sentence alignment -- source: {}, target: {}",
259           srcSegDataGroup, trgSegDataGroup);
260 
261       return;
262     }
263 
264     // Case 1: 1:1 paragraph match
265     if (srcTuGroup.size() == 1 && trgTuGroup.size() == 1) {
266       applySentenceAlignmentsToSinglePair(
267           srcTuGroup.get(0), trgTuGroup.get(0),
268           srcSegDataGroup.get(0), trgSegDataGroup.get(0),
269           sentenceAlignments);
270 
271       return;
272     }
273 
274     // Case 2: Multi-paragraph match - merge source TUs
275     ITextUnit primarySrcTu = srcTuGroup.get(0);
276     List<SegmentDataWithCodes> mergedSrcSegData = new ArrayList<>(srcSegDataGroup.get(0));
277 
278     for (int i = 1; i < srcTuGroup.size(); i++) {
279       ITextUnit additionalTu = srcTuGroup.get(i);
280 
281       for (Segment seg : additionalTu.getSource().getSegments()) {
282         primarySrcTu.getSource().append(seg.clone());
283       }
284 
285       mergedSrcSegData.addAll(srcSegDataGroup.get(i));
286     }
287 
288     // Merge target TUs if multiple
289     ITextUnit primaryTrgTu = trgTuGroup.isEmpty() ? null : trgTuGroup.get(0);
290 
291     List<SegmentDataWithCodes> mergedTrgSegData = trgTuGroup.isEmpty()
292         ? new ArrayList<>()
293         : new ArrayList<>(trgSegDataGroup.get(0));
294 
295     if (primaryTrgTu != null && trgTuGroup.size() > 1) {
296       for (int i = 1; i < trgTuGroup.size(); i++) {
297         ITextUnit additionalTu = trgTuGroup.get(i);
298 
299         for (Segment seg : additionalTu.getSource().getSegments()) {
300           primaryTrgTu.getSource().append(seg.clone());
301         }
302 
303         mergedTrgSegData.addAll(trgSegDataGroup.get(i));
304       }
305     }
306 
307     if (primaryTrgTu != null) {
308       applySentenceAlignmentsToSinglePair(
309           primarySrcTu, primaryTrgTu, mergedSrcSegData, mergedTrgSegData,
310           sentenceAlignments);
311     }
312   }
313 
314   private void applySentenceAlignmentsToSinglePair(
315       ITextUnit sourceTu, ITextUnit targetTu,
316       List<SegmentDataWithCodes> srcSegData, List<SegmentDataWithCodes> trgSegData,
317       List<SentenceAlignment> sentenceAlignments) {
318 
319     TextContainer srcCont = sourceTu.getSource();
320     TextContainer trgCont = sourceTu.createTarget(getTargetLocale(), false, IResource.CREATE_EMPTY);
321     trgCont.clear();
322 
323     List<Segment> srcSegments = new ArrayList<>(srcCont.getSegments().asList());
324     int nextSegmentId = 1;
325     int currentSrcPos = 0;
326 
327     for (SentenceAlignment align : sentenceAlignments) {
328       if (Util.isEmpty(align.sourcePositions) || Util.isEmpty(align.targetPositions)) {
329 
330         Log.warn(getClass(), "Problematic sentence alignment -- source: {}, target: {}",
331             align.sourcePositions, align.targetPositions);
332 
333         continue;
334       }
335 
336       if (params.isLogAlignmentDetails()) {
337         Log.debug(getClass(), "Sentence alignment: type={}, src={}, trg={}",
338             align.type, align.sourcePositions, align.targetPositions);
339       }
340 
341       switch (align.type) {
342       case "MATCH":
343         int srcPos = align.sourcePositions.get(0);
344         int trgPos = align.targetPositions.get(0);
345 
346         String segId = srcSegments.get(srcPos).getId();
347         SegmentDataWithCodes trgData = trgSegData.get(trgPos);
348 
349         // Restore codes to target text
350         TextFragment trgFrag = ConversionUtil.toTextFragment(
351             new LayeredTextX().text(trgData.text).codes(trgData.codes)
352                 .language(getTargetLocale().toString()));
353 
354         trgCont.append(new Segment(segId, trgFrag));
355         currentSrcPos = srcPos + 1;
356         break;
357 
358       case "MULTI_MATCH":
359         int firstSrcPos = align.sourcePositions.get(0);
360         String groupId = srcSegments.get(firstSrcPos).getId();
361 
362         // Merge source segments
363         for (int i = 0; i < align.sourcePositions.size() - 1; i++) {
364           srcCont.getSegments().joinWithNext(firstSrcPos);
365         }
366 
367         // Merge target segments with codes
368         List<SegmentDataWithCodes> trgGroup = new ArrayList<>();
369 
370         for (int tPos : align.targetPositions) {
371           trgGroup.add(trgSegData.get(tPos));
372         }
373 
374         TextFragment mergedFrag = mergeSegmentDataWithCodes(trgGroup);
375 
376         trgCont.append(new Segment(groupId, mergedFrag));
377         currentSrcPos = firstSrcPos + 1;
378         break;
379 
380       case "DELETED":
381         int delSrcPos = align.sourcePositions.get(0);
382         String delSrcId = srcSegments.get(delSrcPos).getId();
383 
384         trgCont.append(new Segment(delSrcId, new TextFragment("")));
385         currentSrcPos = delSrcPos + 1;
386         break;
387 
388       case "INSERTED":
389         int insTrgPos = align.targetPositions.get(0);
390         SegmentDataWithCodes insTrgData = trgSegData.get(insTrgPos);
391 
392         String newId = generateUniqueSegmentId(sourceTu, nextSegmentId++);
393 
394         TextFragment emptySrcFrag = new TextFragment("");
395         srcCont.getSegments().insert(currentSrcPos, new Segment(newId, emptySrcFrag));
396 
397         TextFragment insTrgFrag = ConversionUtil.toTextFragment(
398             new LayeredTextX().text(insTrgData.text).codes(insTrgData.codes)
399                 .language(getTargetLocale().toString()));
400 
401         trgCont.append(new Segment(newId, insTrgFrag));
402         currentSrcPos++;
403         break;
404       }
405     }
406 
407     // Verify segment counts match
408     int srcCount = srcCont.getSegments().count();
409     int trgCount = trgCont.getSegments().count();
410 
411     if (srcCount != trgCount) {
412       throw new OkapiException(
413           String.format("Segment count mismatch in TU '%s': source=%d, target=%d",
414               sourceTu.getId(), srcCount, trgCount));
415     }
416 
417     // Verify all segment IDs match
418     Iterator<Segment> srcIt = srcCont.getSegments().iterator();
419     Iterator<Segment> trgIt = trgCont.getSegments().iterator();
420 
421     while (srcIt.hasNext() && trgIt.hasNext()) {
422       Segment srcSeg = srcIt.next();
423       Segment trgSeg = trgIt.next();
424 
425       if (!srcSeg.getId().equals(trgSeg.getId())) {
426         throw new OkapiException(
427             String.format("Segment ID mismatch in TU '%s': source='%s', target='%s'",
428                 sourceTu.getId(), srcSeg.getId(), trgSeg.getId()));
429       }
430     }
431 
432     trgCont.setHasBeenSegmentedFlag(true);
433     trgCont.getSegments().setAlignmentStatus(AlignmentStatus.ALIGNED);
434 
435     // Align and copy code metadata from source to target
436     if (params.isLogAlignmentDetails()) {
437       Log.debug(getClass(), "Aligning codes between source and target for TU: {}",
438           sourceTu.getId());
439     }
440 
441     srcIt = srcCont.getSegments().iterator();
442     trgIt = trgCont.getSegments().iterator();
443 
444     while (srcIt.hasNext() && trgIt.hasNext()) {
445       Segment srcSeg = srcIt.next();
446       Segment trgSeg = trgIt.next();
447 
448       if (params.isUseCodesReinsertionModel()) {
449         TextUnitUtil.removeCodes(trgSeg.getContent());
450 
451       } else {
452         // LLM-based code re-insertion does this normalization, called only for no-LLM
453         OkapiUtil.removeExtraCodes(srcSeg.getContent().getCodes(), trgSeg.getContent());
454 
455         // Align codes and copy metadata from source to target
456         TextFragmentUtil.alignAndCopyCodeMetadata(srcSeg.text, trgSeg.text, true, true);
457 
458         // Rearrange opening and closing codes
459         OkapiUtil.rearrangeCodes(srcSeg.getContent().getCodes(), trgSeg.getContent());
460       }
461     }
462   }
463 
464   private TextFragment mergeSegmentDataWithCodes(List<SegmentDataWithCodes> segDataList) {
465     if (segDataList.isEmpty()) {
466       return new TextFragment("");
467     }
468 
469     if (segDataList.size() == 1) {
470       SegmentDataWithCodes data = segDataList.get(0);
471 
472       return ConversionUtil.toTextFragment(
473           new LayeredTextX().text(data.text).codes(data.codes)
474               .language(getTargetLocale().toString()));
475     }
476 
477     // Merge texts and codes
478     StringBuilder mergedText = new StringBuilder();
479     List<InlineCode> mergedCodes = new ArrayList<>();
480 
481     int cumulativeOffset = 0;
482 
483     for (int i = 0; i < segDataList.size(); i++) {
484       SegmentDataWithCodes data = segDataList.get(i);
485 
486       if (i > 0) {
487         mergedText.append(" ");
488         cumulativeOffset++;
489       }
490 
491       mergedText.append(data.text);
492 
493       // Adjust code positions
494       for (InlineCode code : data.codes) {
495         InlineCode adjustedCode = new InlineCode();
496         adjustedCode.setId(code.getId());
497         adjustedCode.setPosition(code.getPosition() + cumulativeOffset);
498         adjustedCode.setTagType(code.getTagType());
499         adjustedCode.setType(code.getType());
500         adjustedCode.setData(code.getData());
501         adjustedCode.setOuterData(code.getOuterData());
502         adjustedCode.setFlag(code.getFlag());
503         adjustedCode.setDisplayText(code.getDisplayText());
504         adjustedCode.setOriginalId(code.getOriginalId());
505         mergedCodes.add(adjustedCode);
506       }
507 
508       cumulativeOffset += data.text.length();
509     }
510 
511     return ConversionUtil.toTextFragment(
512         new LayeredTextX().text(mergedText.toString()).codes(mergedCodes)
513             .language(getTargetLocale().toString()));
514   }
515 
516   private String generateUniqueSegmentId(ITextUnit tu, int counter) {
517     String candidateId;
518 
519     do {
520       candidateId = tu.getId() + "_seg_" + counter++;
521 
522     } while (tu.getSource().getSegments().get(candidateId) != null);
523 
524     return candidateId;
525   }
526 }