View Javadoc
1   package com.acumenvelocity.ath.steps;
2   
3   import java.util.ArrayList;
4   import java.util.List;
5   
6   import com.acumenvelocity.ath.common.Const;
7   import com.acumenvelocity.ath.common.ConversionUtil;
8   import com.acumenvelocity.ath.common.Log;
9   import com.acumenvelocity.ath.common.OkapiUtil;
10  import com.acumenvelocity.ath.common.OkapiWordBreaker;
11  import com.acumenvelocity.ath.common.OriginalTuAnnotation;
12  import com.acumenvelocity.ath.gemini.GenAi;
13  import com.acumenvelocity.ath.model.InlineCodeRef;
14  import com.acumenvelocity.ath.model.LayeredSegment;
15  import com.acumenvelocity.ath.model.x.LayeredTextX;
16  
17  import net.sf.okapi.common.Event;
18  import net.sf.okapi.common.IResource;
19  import net.sf.okapi.common.Util;
20  import net.sf.okapi.common.resource.ITextUnit;
21  import net.sf.okapi.common.resource.Segment;
22  import net.sf.okapi.common.resource.TextContainer;
23  import net.sf.okapi.common.resource.TextFragment;
24  import net.sf.okapi.common.resource.TextFragmentUtil;
25  
26  public class CodesReinsertionStep extends BaseTuBatchProcessingStep {
27  
28    private final List<LayeredSegment> lsegs = new ArrayList<>();
29    private final List<LayeredTextX> slts = new ArrayList<>();
30    private boolean useCodesReinsertionModel;
31    private String codesReinsertionModelName;
32  
33    public CodesReinsertionStep(boolean useCodesReinsertionModel, String codesReinsertionModelName) {
34      super();
35      this.useCodesReinsertionModel = useCodesReinsertionModel;
36  
37      this.codesReinsertionModelName = Util.isEmpty(codesReinsertionModelName)
38          ? Const.GEMINI_CODE_REINSERTION_MODEL
39          : codesReinsertionModelName;
40    }
41  
42    @Override
43    public String getName() {
44      return "Code Reinsertion Step";
45    }
46  
47    @Override
48    public String getDescription() {
49      return "Restores the source part of TextUnit resources from an annotation, employs a LLM to "
50          + "insert inline codes in the targets.";
51    }
52  
53    @Override
54    protected void clear() {
55      lsegs.clear();
56      slts.clear();
57    }
58  
59    private void preProcessTextUnit(ITextUnit tu) {
60      OriginalTuAnnotation ota = tu.getAnnotation(OriginalTuAnnotation.class);
61  
62      // OriginalTuAnnotation is set by MtLeveragingStep that removes source codes for MT, for
63      // alignment we use the original source codes
64      if (ota != null) {
65        ITextUnit otu = ota.getTu();
66        tu.setSource(otu.getSource());
67      }
68  
69      TextContainer src = tu.getSource();
70      TextContainer trg = tu.getTarget(getTargetLocale());
71  
72      if (src == null) {
73        Log.error(this.getClass(), "Source of '{}' is null", tu.getId());
74        return;
75      }
76  
77      // Target can be null for not-yet-translated segments, warn and continue
78      if (trg == null) {
79        Log.warn(this.getClass(), "Target ''{}'' of '{}' is null, creating an empty target",
80            getTargetLocale(), tu.getId());
81  
82        trg = tu.createTarget(getTargetLocale(), false, IResource.COPY_SEGMENTATION);
83      }
84  
85      if (!useCodesReinsertionModel) {
86        return;
87      }
88  
89      for (Segment sseg : src.getSegments()) {
90        // Don't include in analysis an empty list of source codes
91        if (Util.isEmpty(sseg.getContent().getCodes())) {
92          continue;
93        }
94  
95        Segment tseg = trg.getSegments().get(sseg.getId());
96  
97        LayeredTextX slt = ConversionUtil.toLayeredText(sseg.text, getSourceLocale());
98        LayeredTextX tlt = ConversionUtil.toLayeredText(tseg.text, getTargetLocale());
99  
100       LayeredSegment lseg = ConversionUtil.toLayeredSegment(slt, tlt);
101 
102       lseg.setTrgWordBreakPositions(
103           OkapiWordBreaker.getWordBreakPositions(tlt.getText(), getTargetLocale()));
104 
105       lsegs.add(lseg);
106       slts.add(slt);
107     }
108   }
109 
110   /**
111    * Update the target codes from the LLM API response.
112    * 
113    * @param trgCodesList
114    */
115   private void postProcessTextUnits(List<Event> tuEvents, List<List<InlineCodeRef>> trgCodesList) {
116     int index = 0;
117 
118     for (Event tue : tuEvents) {
119       ITextUnit tu = tue.getTextUnit();
120 
121       TextContainer src = tu.getSource();
122       TextContainer trg = tu.getTarget(getTargetLocale());
123 
124       if (src == null) {
125         Log.error(this.getClass(), "Source of '{}' is null", tu.getId());
126         continue;
127       }
128 
129       if (useCodesReinsertionModel) {
130         for (Segment sseg : src.getSegments()) {
131           Log.trace(this.getClass(), "sseg: '{}'", sseg);
132 
133           // Don't include in analysis an empty list of source codes
134           if (Util.isEmpty(sseg.getContent().getCodes())) {
135             continue;
136           }
137 
138           Segment tseg = trg.getSegments().get(sseg.getId());
139           Log.trace(this.getClass(), "tseg: '{}'", tseg);
140 
141           // Create an empty target segment with the same Id as of the source segment
142           if (tseg == null) {
143             tseg = new Segment(sseg.getId());
144             trg.append(tseg);
145             Log.trace(this.getClass(), "Created a missing tseg: '{}'", tseg);
146           }
147 
148           if (index < trgCodesList.size()) {
149             List<InlineCodeRef> trgCodes = trgCodesList.get(index);
150             LayeredSegment lseg = lsegs.get(index);
151             lseg.setTrgCodes(trgCodes);
152             LayeredTextX slt = slts.get(index);
153             LayeredTextX tlt = ConversionUtil.tltFromLayeredSegment(lseg, slt);
154 
155             try {
156               TextFragment ttf = ConversionUtil.toTextFragment(tlt);
157               tseg.setContent(ttf);
158 
159             } catch (Exception e) {
160               Log.warn(getClass(), "Error converting to TextFragment: {}", e.getMessage());
161               continue;
162             }
163 
164             index++;
165           }
166         }
167       }
168 
169       // Fix codes after AI if needed.
170       // If AI is not used, this code reinserts target codes.
171       for (Segment sseg : src.getSegments()) {
172         Segment tseg = trg.getSegments().get(sseg.getId());
173         
174         if (tseg == null) {
175           continue;
176         }
177 
178         // For empty targets we copy source codes
179         TextFragment segSource = sseg.getContent();
180         TextFragment segTarget = tseg.getContent();
181 
182         // This normalization is needed even after LLM alignment that can hypothetically create
183         // extra codes
184         OkapiUtil.removeExtraCodes(segSource.getCodes(), segTarget);
185 
186         // Align codes and copy metadata from source to target
187         TextFragmentUtil.alignAndCopyCodeMetadata(segSource, segTarget, true, true);
188         
189         // Rearrange opening and closing codes
190         OkapiUtil.rearrangeCodes(segSource.getCodes(), segTarget);
191       }
192     }
193   }
194 
195   @Override
196   protected void processTuEvents(List<Event> tuEvents) {
197     // Process TUs
198     for (Event tue : tuEvents) {
199       ITextUnit tu = tue.getTextUnit();
200       preProcessTextUnit(tu);
201     }
202 
203     List<List<InlineCodeRef>> trgCodesList = useCodesReinsertionModel
204         ? GenAi.reinsertCodes(codesReinsertionModelName, lsegs)
205         : null;
206 
207     postProcessTextUnits(tuEvents, trgCodesList);
208   }
209 }