1 package net.sf.okapi.steps.llmsentencealigner;
2
3 import java.util.ArrayList;
4 import java.util.Iterator;
5 import java.util.List;
6
7 import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignment;
8 import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignmentInput;
9 import com.acumenvelocity.ath.common.AlignmentData.CombinedAlignmentOutput;
10 import com.acumenvelocity.ath.common.AlignmentData.ParagraphAlignment;
11 import com.acumenvelocity.ath.common.AlignmentData.ParagraphWithSegments;
12 import com.acumenvelocity.ath.common.AlignmentData.SegmentInfo;
13 import com.acumenvelocity.ath.common.AlignmentData.SentenceAlignment;
14 import com.acumenvelocity.ath.common.ConversionUtil;
15 import com.acumenvelocity.ath.common.Log;
16 import com.acumenvelocity.ath.common.OkapiUtil;
17 import com.acumenvelocity.ath.gemini.GenAi;
18 import com.acumenvelocity.ath.model.InlineCode;
19 import com.acumenvelocity.ath.model.x.LayeredTextX;
20 import com.acumenvelocity.ath.steps.BaseAlignerStep;
21
22 import net.sf.okapi.common.IParameters;
23 import net.sf.okapi.common.IResource;
24 import net.sf.okapi.common.UsingParameters;
25 import net.sf.okapi.common.Util;
26 import net.sf.okapi.common.exceptions.OkapiException;
27 import net.sf.okapi.common.filters.IFilter;
28 import net.sf.okapi.common.resource.AlignmentStatus;
29 import net.sf.okapi.common.resource.ITextUnit;
30 import net.sf.okapi.common.resource.Segment;
31 import net.sf.okapi.common.resource.TextContainer;
32 import net.sf.okapi.common.resource.TextFragment;
33 import net.sf.okapi.common.resource.TextFragmentUtil;
34 import net.sf.okapi.common.resource.TextUnitUtil;
35
36 @UsingParameters(LlmSentenceAlignerParameters.class)
37 public class LlmSentenceAlignerStep extends BaseAlignerStep {
38
39 private LlmSentenceAlignerParameters params;
40
41 private final List<List<SegmentDataWithCodes>> sourceSegmentData = new ArrayList<>();
42 private final List<List<SegmentDataWithCodes>> targetSegmentData = new ArrayList<>();
43
44 public LlmSentenceAlignerStep(IFilter targetFilter) {
45 super(targetFilter);
46 params = new LlmSentenceAlignerParameters();
47 }
48
49 private static class SegmentDataWithCodes {
50 String text;
51 List<InlineCode> codes = new ArrayList<>();
52 }
53
54 @Override
55 public String getName() {
56 return "LLM Sentence Alignment";
57 }
58
59 @Override
60 public String getDescription() {
61 return "Aligns paragraphs and sentences using LLM. Handles crossed paragraphs and different document structures.";
62 }
63
64 @Override
65 public LlmSentenceAlignerParameters getParameters() {
66 return params;
67 }
68
69 @Override
70 public void setParameters(IParameters params) {
71 this.params = (LlmSentenceAlignerParameters) params;
72 }
73
74 @Override
75 protected boolean isSegmentSource() {
76 return params.isSegmentSource();
77 }
78
79 @Override
80 protected boolean isSegmentTarget() {
81 return params.isSegmentTarget();
82 }
83
84 @Override
85 protected boolean isUseCustomSourceRules() {
86 return params.isUseCustomSourceRules();
87 }
88
89 @Override
90 protected boolean isUseCustomTargetRules() {
91 return params.isUseCustomTargetRules();
92 }
93
94 @Override
95 protected String getCustomSourceRulesPath() {
96 return params.getCustomSourceRulesPath();
97 }
98
99 @Override
100 protected String getCustomTargetRulesPath() {
101 return params.getCustomTargetRulesPath();
102 }
103
104 @Override
105 protected boolean isCollapseWhitespace() {
106 return params.isCollapseWhitespace();
107 }
108
109 @Override
110 protected void clear() {
111 super.clear();
112 sourceSegmentData.clear();
113 targetSegmentData.clear();
114 }
115
116 @Override
117 protected void performAlignment(List<ITextUnit> sourceTUs, List<ITextUnit> targetTUs) {
118
119 CombinedAlignmentOutput alignmentOutput = performCombinedAlignment(sourceTUs, targetTUs);
120
121
122 applyCombinedAlignments(alignmentOutput, sourceTUs, targetTUs);
123
124
125 for (ITextUnit tu : sourceTUs) {
126 OkapiUtil.setAlOrigin(tu, getSourceLocale(), getTargetLocale());
127 }
128 }
129
130 private CombinedAlignmentOutput performCombinedAlignment(List<ITextUnit> sourceTUs,
131 List<ITextUnit> targetTUs) {
132 CombinedAlignmentInput input = new CombinedAlignmentInput();
133 input.sourceLanguage = getSourceLocale().toString();
134 input.targetLanguage = getTargetLocale().toString();
135 input.task = "Align paragraphs first, then align sentences within each paragraph pair";
136
137
138 for (ITextUnit srcTu : sourceTUs) {
139 ParagraphWithSegments pws = new ParagraphWithSegments();
140 pws.position = input.sourceParagraphs.size();
141 pws.id = srcTu.getId();
142 pws.context = srcTu.getName();
143
144 List<SegmentDataWithCodes> segDataList = new ArrayList<>();
145
146 for (Segment seg : srcTu.getSource().getSegments()) {
147 LayeredTextX slt = ConversionUtil.toLayeredText(seg.text, getSourceLocale());
148
149 SegmentInfo si = new SegmentInfo();
150 si.position = pws.segments.size();
151 si.text = slt.getText();
152 pws.segments.add(si);
153
154 SegmentDataWithCodes sdc = new SegmentDataWithCodes();
155 sdc.text = slt.getText();
156 sdc.codes = slt.getCodes();
157 segDataList.add(sdc);
158 }
159
160 input.sourceParagraphs.add(pws);
161 sourceSegmentData.add(segDataList);
162 }
163
164
165 for (ITextUnit trgTu : targetTUs) {
166 ParagraphWithSegments pws = new ParagraphWithSegments();
167 pws.position = input.targetParagraphs.size();
168 pws.id = trgTu.getId();
169 pws.context = trgTu.getName();
170
171 List<SegmentDataWithCodes> segDataList = new ArrayList<>();
172
173 for (Segment seg : trgTu.getSource().getSegments()) {
174 LayeredTextX tlt = ConversionUtil.toLayeredText(seg.text, getTargetLocale());
175
176 SegmentInfo si = new SegmentInfo();
177 si.position = pws.segments.size();
178 si.text = tlt.getText();
179 pws.segments.add(si);
180
181 SegmentDataWithCodes sdc = new SegmentDataWithCodes();
182 sdc.text = tlt.getText();
183 sdc.codes = tlt.getCodes();
184 segDataList.add(sdc);
185 }
186
187 input.targetParagraphs.add(pws);
188 targetSegmentData.add(segDataList);
189 }
190
191
192 try {
193 return GenAi.alignParagraphsAndSentences(params.getAlignmentModelName(), input);
194
195 } catch (Exception e) {
196 Log.error(getClass(), "LLM alignment failed: {}", e.getMessage(), e);
197 throw new OkapiException("LLM alignment failed", e);
198 }
199 }
200
201 private void applyCombinedAlignments(CombinedAlignmentOutput output, List<ITextUnit> sourceTUs,
202 List<ITextUnit> targetTUs) {
203
204 for (CombinedAlignment combined : output.alignments) {
205 ParagraphAlignment paraAlign = combined.paragraphAlignment;
206
207 if (params.isLogAlignmentDetails()) {
208 Log.debug(getClass(), "Paragraph alignment: type={}, src={}, trg={}",
209 paraAlign.type, paraAlign.sourceParagraphPositions, paraAlign.targetParagraphPositions);
210 }
211
212 if (Util.isEmpty(paraAlign.sourceParagraphPositions)
213 || Util.isEmpty(paraAlign.targetParagraphPositions)) {
214
215 Log.warn(getClass(), "Problematic para alignment -- source: {}, target: {}",
216 paraAlign.sourceParagraphPositions, paraAlign.targetParagraphPositions);
217
218 continue;
219 }
220
221
222 List<ITextUnit> srcTuGroup = new ArrayList<>();
223 List<List<SegmentDataWithCodes>> srcSegDataGroup = new ArrayList<>();
224
225 for (int srcParaIndex : paraAlign.sourceParagraphPositions) {
226 srcTuGroup.add(sourceTUs.get(srcParaIndex));
227 srcSegDataGroup.add(sourceSegmentData.get(srcParaIndex));
228 }
229
230 List<ITextUnit> trgTuGroup = new ArrayList<>();
231 List<List<SegmentDataWithCodes>> trgSegDataGroup = new ArrayList<>();
232
233 for (int trgParaIndex : paraAlign.targetParagraphPositions) {
234 trgTuGroup.add(targetTUs.get(trgParaIndex));
235 trgSegDataGroup.add(targetSegmentData.get(trgParaIndex));
236 }
237
238 applySentenceAlignmentsToParagraphPair(
239 srcTuGroup, trgTuGroup, srcSegDataGroup, trgSegDataGroup,
240 combined.sentenceAlignments);
241 }
242 }
243
244 private void applySentenceAlignmentsToParagraphPair(
245 List<ITextUnit> srcTuGroup, List<ITextUnit> trgTuGroup,
246 List<List<SegmentDataWithCodes>> srcSegDataGroup,
247 List<List<SegmentDataWithCodes>> trgSegDataGroup,
248 List<SentenceAlignment> sentenceAlignments) {
249
250 if (Util.isEmpty(srcTuGroup) || Util.isEmpty(trgTuGroup)) {
251 Log.warn(getClass(), "Problematic para alignment -- source: {}, target: {}",
252 srcTuGroup, trgTuGroup);
253
254 return;
255 }
256
257 if (Util.isEmpty(srcSegDataGroup) || Util.isEmpty(trgSegDataGroup)) {
258 Log.warn(getClass(), "Problematic sentence alignment -- source: {}, target: {}",
259 srcSegDataGroup, trgSegDataGroup);
260
261 return;
262 }
263
264
265 if (srcTuGroup.size() == 1 && trgTuGroup.size() == 1) {
266 applySentenceAlignmentsToSinglePair(
267 srcTuGroup.get(0), trgTuGroup.get(0),
268 srcSegDataGroup.get(0), trgSegDataGroup.get(0),
269 sentenceAlignments);
270
271 return;
272 }
273
274
275 ITextUnit primarySrcTu = srcTuGroup.get(0);
276 List<SegmentDataWithCodes> mergedSrcSegData = new ArrayList<>(srcSegDataGroup.get(0));
277
278 for (int i = 1; i < srcTuGroup.size(); i++) {
279 ITextUnit additionalTu = srcTuGroup.get(i);
280
281 for (Segment seg : additionalTu.getSource().getSegments()) {
282 primarySrcTu.getSource().append(seg.clone());
283 }
284
285 mergedSrcSegData.addAll(srcSegDataGroup.get(i));
286 }
287
288
289 ITextUnit primaryTrgTu = trgTuGroup.isEmpty() ? null : trgTuGroup.get(0);
290
291 List<SegmentDataWithCodes> mergedTrgSegData = trgTuGroup.isEmpty()
292 ? new ArrayList<>()
293 : new ArrayList<>(trgSegDataGroup.get(0));
294
295 if (primaryTrgTu != null && trgTuGroup.size() > 1) {
296 for (int i = 1; i < trgTuGroup.size(); i++) {
297 ITextUnit additionalTu = trgTuGroup.get(i);
298
299 for (Segment seg : additionalTu.getSource().getSegments()) {
300 primaryTrgTu.getSource().append(seg.clone());
301 }
302
303 mergedTrgSegData.addAll(trgSegDataGroup.get(i));
304 }
305 }
306
307 if (primaryTrgTu != null) {
308 applySentenceAlignmentsToSinglePair(
309 primarySrcTu, primaryTrgTu, mergedSrcSegData, mergedTrgSegData,
310 sentenceAlignments);
311 }
312 }
313
314 private void applySentenceAlignmentsToSinglePair(
315 ITextUnit sourceTu, ITextUnit targetTu,
316 List<SegmentDataWithCodes> srcSegData, List<SegmentDataWithCodes> trgSegData,
317 List<SentenceAlignment> sentenceAlignments) {
318
319 TextContainer srcCont = sourceTu.getSource();
320 TextContainer trgCont = sourceTu.createTarget(getTargetLocale(), false, IResource.CREATE_EMPTY);
321 trgCont.clear();
322
323 List<Segment> srcSegments = new ArrayList<>(srcCont.getSegments().asList());
324 int nextSegmentId = 1;
325 int currentSrcPos = 0;
326
327 for (SentenceAlignment align : sentenceAlignments) {
328 if (Util.isEmpty(align.sourcePositions) || Util.isEmpty(align.targetPositions)) {
329
330 Log.warn(getClass(), "Problematic sentence alignment -- source: {}, target: {}",
331 align.sourcePositions, align.targetPositions);
332
333 continue;
334 }
335
336 if (params.isLogAlignmentDetails()) {
337 Log.debug(getClass(), "Sentence alignment: type={}, src={}, trg={}",
338 align.type, align.sourcePositions, align.targetPositions);
339 }
340
341 switch (align.type) {
342 case "MATCH":
343 int srcPos = align.sourcePositions.get(0);
344 int trgPos = align.targetPositions.get(0);
345
346 String segId = srcSegments.get(srcPos).getId();
347 SegmentDataWithCodes trgData = trgSegData.get(trgPos);
348
349
350 TextFragment trgFrag = ConversionUtil.toTextFragment(
351 new LayeredTextX().text(trgData.text).codes(trgData.codes)
352 .language(getTargetLocale().toString()));
353
354 trgCont.append(new Segment(segId, trgFrag));
355 currentSrcPos = srcPos + 1;
356 break;
357
358 case "MULTI_MATCH":
359 int firstSrcPos = align.sourcePositions.get(0);
360 String groupId = srcSegments.get(firstSrcPos).getId();
361
362
363 for (int i = 0; i < align.sourcePositions.size() - 1; i++) {
364 srcCont.getSegments().joinWithNext(firstSrcPos);
365 }
366
367
368 List<SegmentDataWithCodes> trgGroup = new ArrayList<>();
369
370 for (int tPos : align.targetPositions) {
371 trgGroup.add(trgSegData.get(tPos));
372 }
373
374 TextFragment mergedFrag = mergeSegmentDataWithCodes(trgGroup);
375
376 trgCont.append(new Segment(groupId, mergedFrag));
377 currentSrcPos = firstSrcPos + 1;
378 break;
379
380 case "DELETED":
381 int delSrcPos = align.sourcePositions.get(0);
382 String delSrcId = srcSegments.get(delSrcPos).getId();
383
384 trgCont.append(new Segment(delSrcId, new TextFragment("")));
385 currentSrcPos = delSrcPos + 1;
386 break;
387
388 case "INSERTED":
389 int insTrgPos = align.targetPositions.get(0);
390 SegmentDataWithCodes insTrgData = trgSegData.get(insTrgPos);
391
392 String newId = generateUniqueSegmentId(sourceTu, nextSegmentId++);
393
394 TextFragment emptySrcFrag = new TextFragment("");
395 srcCont.getSegments().insert(currentSrcPos, new Segment(newId, emptySrcFrag));
396
397 TextFragment insTrgFrag = ConversionUtil.toTextFragment(
398 new LayeredTextX().text(insTrgData.text).codes(insTrgData.codes)
399 .language(getTargetLocale().toString()));
400
401 trgCont.append(new Segment(newId, insTrgFrag));
402 currentSrcPos++;
403 break;
404 }
405 }
406
407
408 int srcCount = srcCont.getSegments().count();
409 int trgCount = trgCont.getSegments().count();
410
411 if (srcCount != trgCount) {
412 throw new OkapiException(
413 String.format("Segment count mismatch in TU '%s': source=%d, target=%d",
414 sourceTu.getId(), srcCount, trgCount));
415 }
416
417
418 Iterator<Segment> srcIt = srcCont.getSegments().iterator();
419 Iterator<Segment> trgIt = trgCont.getSegments().iterator();
420
421 while (srcIt.hasNext() && trgIt.hasNext()) {
422 Segment srcSeg = srcIt.next();
423 Segment trgSeg = trgIt.next();
424
425 if (!srcSeg.getId().equals(trgSeg.getId())) {
426 throw new OkapiException(
427 String.format("Segment ID mismatch in TU '%s': source='%s', target='%s'",
428 sourceTu.getId(), srcSeg.getId(), trgSeg.getId()));
429 }
430 }
431
432 trgCont.setHasBeenSegmentedFlag(true);
433 trgCont.getSegments().setAlignmentStatus(AlignmentStatus.ALIGNED);
434
435
436 if (params.isLogAlignmentDetails()) {
437 Log.debug(getClass(), "Aligning codes between source and target for TU: {}",
438 sourceTu.getId());
439 }
440
441 srcIt = srcCont.getSegments().iterator();
442 trgIt = trgCont.getSegments().iterator();
443
444 while (srcIt.hasNext() && trgIt.hasNext()) {
445 Segment srcSeg = srcIt.next();
446 Segment trgSeg = trgIt.next();
447
448 if (params.isUseCodesReinsertionModel()) {
449 TextUnitUtil.removeCodes(trgSeg.getContent());
450
451 } else {
452
453 OkapiUtil.removeExtraCodes(srcSeg.getContent().getCodes(), trgSeg.getContent());
454
455
456 TextFragmentUtil.alignAndCopyCodeMetadata(srcSeg.text, trgSeg.text, true, true);
457
458
459 OkapiUtil.rearrangeCodes(srcSeg.getContent().getCodes(), trgSeg.getContent());
460 }
461 }
462 }
463
464 private TextFragment mergeSegmentDataWithCodes(List<SegmentDataWithCodes> segDataList) {
465 if (segDataList.isEmpty()) {
466 return new TextFragment("");
467 }
468
469 if (segDataList.size() == 1) {
470 SegmentDataWithCodes data = segDataList.get(0);
471
472 return ConversionUtil.toTextFragment(
473 new LayeredTextX().text(data.text).codes(data.codes)
474 .language(getTargetLocale().toString()));
475 }
476
477
478 StringBuilder mergedText = new StringBuilder();
479 List<InlineCode> mergedCodes = new ArrayList<>();
480
481 int cumulativeOffset = 0;
482
483 for (int i = 0; i < segDataList.size(); i++) {
484 SegmentDataWithCodes data = segDataList.get(i);
485
486 if (i > 0) {
487 mergedText.append(" ");
488 cumulativeOffset++;
489 }
490
491 mergedText.append(data.text);
492
493
494 for (InlineCode code : data.codes) {
495 InlineCode adjustedCode = new InlineCode();
496 adjustedCode.setId(code.getId());
497 adjustedCode.setPosition(code.getPosition() + cumulativeOffset);
498 adjustedCode.setTagType(code.getTagType());
499 adjustedCode.setType(code.getType());
500 adjustedCode.setData(code.getData());
501 adjustedCode.setOuterData(code.getOuterData());
502 adjustedCode.setFlag(code.getFlag());
503 adjustedCode.setDisplayText(code.getDisplayText());
504 adjustedCode.setOriginalId(code.getOriginalId());
505 mergedCodes.add(adjustedCode);
506 }
507
508 cumulativeOffset += data.text.length();
509 }
510
511 return ConversionUtil.toTextFragment(
512 new LayeredTextX().text(mergedText.toString()).codes(mergedCodes)
513 .language(getTargetLocale().toString()));
514 }
515
516 private String generateUniqueSegmentId(ITextUnit tu, int counter) {
517 String candidateId;
518
519 do {
520 candidateId = tu.getId() + "_seg_" + counter++;
521
522 } while (tu.getSource().getSegments().get(candidateId) != null);
523
524 return candidateId;
525 }
526 }