View Javadoc
1   package com.acumenvelocity.ath.mt.confidence;
2   
3   import com.acumenvelocity.ath.common.AthUtil;
4   import com.acumenvelocity.ath.common.Const;
5   import com.acumenvelocity.ath.common.Log;
6   import com.google.gson.*;
7   
8   import java.io.*;
9   import java.nio.charset.StandardCharsets;
10  import java.nio.file.Files;
11  import java.nio.file.Path;
12  import java.util.*;
13  import java.util.concurrent.*;
14  import java.util.stream.Collectors;
15  
16  public class VertexQualityEstimator implements AutoCloseable {
17  
18    private final ExecutorService executor = Executors.newFixedThreadPool(
19        Math.max(2, Runtime.getRuntime().availableProcessors() - 1));
20  
21    private final Gson gson = new Gson();
22  
23    private static final String PYTHON_CMD = System.getProperty("python.cmd", "python3");
24    private static final String SCRIPT_RESOURCE_PATH = "/eval_qe_production.py";
25  
26    // Temp script extracted from classpath — created once
27    private final Path extractedScriptPath;
28  
29    public VertexQualityEstimator() throws IOException {
30      this.extractedScriptPath = extractScriptFromClasspath();
31    }
32  
33    private Path extractScriptFromClasspath() throws IOException {
34      InputStream is = getClass().getResourceAsStream(SCRIPT_RESOURCE_PATH);
35  
36      if (is == null) {
37        throw new IOException("Python script not found in classpath: " + SCRIPT_RESOURCE_PATH +
38            " — make sure eval_qe_production.py is in src/main/resources/");
39      }
40  
41      Path tempFile = Files.createTempFile("eval_qe_", ".py");
42      tempFile.toFile().deleteOnExit();
43  
44      try (OutputStream os = Files.newOutputStream(tempFile)) {
45        is.transferTo(os);
46      }
47  
48      if (!tempFile.toFile().setExecutable(true)) {
49        Log.warn(getClass(), "Could not set executable flag on temp Python script: {}", tempFile);
50      }
51  
52      Log.info(getClass(), "Successfully extracted Python QE script to: {}", tempFile);
53      return tempFile;
54    }
55  
56    public VertexDataStructures.TranslationComparisonResult evaluateMultiModelTranslations(
57        List<CommonDataStructures.SegmentTranslations> segments,
58        String sourceLang, String targetLang) throws IOException {
59  
60      List<CompletableFuture<VertexDataStructures.ScoredSegmentTranslations>> futures = new ArrayList<>();
61  
62      for (CommonDataStructures.SegmentTranslations seg : segments) {
63        futures.add(CompletableFuture.supplyAsync(() -> evaluateSegment(seg), executor));
64      }
65  
66      List<VertexDataStructures.ScoredSegmentTranslations> scored = futures.stream()
67          .map(CompletableFuture::join)
68          .filter(Objects::nonNull)
69          .collect(Collectors.toList());
70  
71      return new VertexDataStructures.TranslationComparisonResult(scored);
72    }
73  
74    public VertexDataStructures.ScoredSegmentTranslations evaluateSegment(
75        CommonDataStructures.SegmentTranslations seg) {
76      try {
77        JsonArray payload = new JsonArray();
78        List<String> modelIds = new ArrayList<>();
79  
80        for (Map.Entry<String, String> e : seg.getTranslations().entrySet()) {
81          JsonObject obj = new JsonObject();
82          obj.addProperty("prompt", seg.getSourceText());
83          obj.addProperty("response", e.getValue());
84          payload.add(obj);
85          modelIds.add(e.getKey());
86        }
87  
88        ProcessBuilder pb = new ProcessBuilder(PYTHON_CMD, extractedScriptPath.toString());
89        pb.environment().put("GOOGLE_APPLICATION_CREDENTIALS", Const.ATH_GCP_SECRET_FILE);
90  
91        Process p = pb.start();
92  
93        try (OutputStream os = p.getOutputStream()) {
94          os.write(payload.toString().getBytes(StandardCharsets.UTF_8));
95          os.flush();
96        }
97  
98        // Read all output
99        StringBuilder fullOutput = new StringBuilder();
100 
101       try (BufferedReader reader = new BufferedReader(
102           new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
103         String line;
104 
105         while ((line = reader.readLine()) != null) {
106           fullOutput.append(line).append("\n");
107         }
108       }
109 
110       StringBuilder stderr = new StringBuilder();
111 
112       try (BufferedReader errorReader = new BufferedReader(
113           new InputStreamReader(p.getErrorStream(), StandardCharsets.UTF_8))) {
114         String line;
115 
116         while ((line = errorReader.readLine()) != null) {
117           stderr.append(line).append("\n");
118         }
119       }
120 
121       int exitCode = p.waitFor();
122 
123       if (exitCode != 0) {
124         Log.error(getClass(),
125             "Python QE failed for segment {}, exitCode={}\nSTDERR:\n{}\nSTDOUT:\n{}",
126             seg.getSegmentId(), exitCode,
127             stderr.length() > 0 ? stderr.toString().trim() : "(empty)",
128             fullOutput.length() > 0 ? fullOutput.toString().trim() : "(empty)");
129 
130         return null;
131       }
132 
133       // EXTRACT ONLY THE LINE THAT LOOKS LIKE JSON
134       String jsonLine = fullOutput.toString().lines()
135           .filter(line -> line.trim().startsWith("{") && line.trim().endsWith("}"))
136           .findFirst()
137           .orElse(null);
138 
139       if (jsonLine == null) {
140         Log.error(getClass(), "No valid JSON found in Python output for segment {}\nOUTPUT:\n{}",
141             seg.getSegmentId(), fullOutput.toString().trim());
142 
143         return null;
144       }
145 
146       JsonObject result = gson.fromJson(jsonLine, JsonObject.class);
147       List<VertexDataStructures.ScoredTranslation> list = new ArrayList<>();
148 
149       for (int i = 0; i < modelIds.size(); i++) {
150         String key = "alt_" + i;
151         JsonElement altElem = result.get(key);
152 
153         if (altElem == null || !altElem.isJsonObject()) {
154           Log.warn(getClass(), "Missing alt_{} in result for segment {}", i, seg.getSegmentId());
155           continue;
156         }
157 
158         JsonObject altObj = altElem.getAsJsonObject();
159         JsonElement metricxElem = altObj.get("metricx");
160 
161         if (metricxElem == null || !metricxElem.isJsonPrimitive()) {
162           Log.warn(getClass(), "Invalid metricx for alt_{} in segment {}", i, seg.getSegmentId());
163           continue;
164         }
165 
166         double metricx = metricxElem.getAsDouble();
167         String src = seg.getSourceText();
168         String trg = seg.getTranslations().get(modelIds.get(i));
169 
170         Log.info(getClass(), "{} trl by '{}': '{}' from src: '{}'",
171             altObj,
172             AthUtil.extractLastSection(modelIds.get(i)),
173             trg,
174             src);
175 
176         list.add(new VertexDataStructures.ScoredTranslation(
177             modelIds.get(i),
178             seg.getTranslations().get(modelIds.get(i)),
179             metricx,
180             altObj.has("fluency") ? altObj.get("fluency").getAsDouble() : 0.0,
181             altObj.has("coherence") ? altObj.get("coherence").getAsDouble() : 0.0,
182             altObj.has("groundedness") ? altObj.get("groundedness").getAsDouble() : 0.0,
183             HeuristicQualityEstimator.calculateLengthRatio(src, trg),
184             HeuristicQualityEstimator.calculateTextSimilarity(src, trg)));
185       }
186 
187       if (list.isEmpty()) {
188         Log.warn(getClass(), "No valid scores parsed for segment {}", seg.getSegmentId());
189         return null;
190       }
191 
192       return new VertexDataStructures.ScoredSegmentTranslations(
193           seg.getSegmentId(), seg.getSourceText(), seg.getSourceTf(), list);
194 
195     } catch (Exception e) {
196       Log.error(getClass(), e, "Unexpected error in Gemini QE for segment {}", seg.getSegmentId());
197       return null;
198     }
199   }
200 
201   @Override
202   public void close() {
203     executor.shutdownNow();
204 
205     if (extractedScriptPath != null && Files.exists(extractedScriptPath)) {
206       try {
207         Files.deleteIfExists(extractedScriptPath);
208 
209       } catch (Exception ignored) {
210       }
211     }
212   }
213 }