001// Copyright (c) Choreo contributors
002
003package choreo;
004
import static edu.wpi.first.util.ErrorMessages.requireNonNullParam;

import choreo.auto.AutoChooser;
import choreo.auto.AutoFactory;
import choreo.auto.AutoFactory.AutoBindings;
import choreo.auto.AutoRoutine;
import choreo.auto.AutoTrajectory;
import choreo.trajectory.DifferentialSample;
import choreo.trajectory.EventMarker;
import choreo.trajectory.ProjectFile;
import choreo.trajectory.SwerveSample;
import choreo.trajectory.Trajectory;
import choreo.trajectory.TrajectorySample;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.wpilibj.DriverStation;
import edu.wpi.first.wpilibj.Filesystem;
import edu.wpi.first.wpilibj2.command.Command;
import edu.wpi.first.wpilibj2.command.Subsystem;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BooleanSupplier;
import java.util.function.Supplier;
041
042/** Utilities to load and follow Choreo Trajectories */
043public final class Choreo {
044  private static final Gson GSON =
045      new GsonBuilder()
046          .registerTypeAdapter(EventMarker.class, new EventMarker.Deserializer())
047          .create();
048  private static final String TRAJECTORY_FILE_EXTENSION = ".traj";
049  private static final String SPEC_VERSION = "v2025.0.0";
050
051  private static File CHOREO_DIR = new File(Filesystem.getDeployDirectory(), "choreo");
052
053  private static Optional<ProjectFile> LAZY_PROJECT_FILE = Optional.empty();
054
055  /** This should only be used for unit testing. */
056  static void setChoreoDir(File choreoDir) {
057    CHOREO_DIR = choreoDir;
058  }
059
060  /**
061   * Gets the project file from the deploy directory. Choreolib expects a .chor file to be placed in
062   * src/main/deploy/choreo.
063   *
064   * <p>The result is cached after the first call.
065   *
066   * @return the project file
067   */
068  public static ProjectFile getProjectFile() {
069    if (LAZY_PROJECT_FILE.isPresent()) {
070      return LAZY_PROJECT_FILE.get();
071    }
072    try {
073      // find the first file that ends with a .chor extension
074      File[] projectFiles = CHOREO_DIR.listFiles((dir, name) -> name.endsWith(".chor"));
075      if (projectFiles.length == 0) {
076        throw new RuntimeException("Could not find project file in deploy directory");
077      } else if (projectFiles.length > 1) {
078        throw new RuntimeException("Found multiple project files in deploy directory");
079      }
080      BufferedReader reader = new BufferedReader(new FileReader(projectFiles[0]));
081      String str = reader.lines().reduce("", (a, b) -> a + b);
082      reader.close();
083      JsonObject json = GSON.fromJson(str, JsonObject.class);
084      String version = json.get("version").getAsString();
085      if (!SPEC_VERSION.equals(version)) {
086        throw new RuntimeException(
087            ".chor project file: Wrong version " + version + ". Expected " + SPEC_VERSION);
088      }
089      LAZY_PROJECT_FILE = Optional.of(GSON.fromJson(str, ProjectFile.class));
090    } catch (JsonSyntaxException ex) {
091      throw new RuntimeException("Could not parse project file", ex);
092    } catch (FileNotFoundException ex) {
093      throw new RuntimeException("Could not find project file", ex);
094    } catch (IOException ex) {
095      throw new RuntimeException("Could not find project file", ex);
096    }
097    return LAZY_PROJECT_FILE.get();
098  }
099
100  /**
101   * This interface exists as a type alias. A TrajectoryLogger has a signature of ({@link
102   * Trajectory}, {@link Boolean})-&gt;void, where the function consumes a trajectory and a boolean
103   * indicating whether the trajectory is starting or finishing.
104   *
105   * @param <SampleType> DifferentialSample or SwerveSample.
106   */
107  public interface TrajectoryLogger<SampleType extends TrajectorySample<SampleType>>
108      extends BiConsumer<Trajectory<SampleType>, Boolean> {}
109
110  /** Default constructor. */
111  private Choreo() {
112    throw new UnsupportedOperationException("This is a utility class!");
113  }
114
115  /**
116   * Load a trajectory from the deploy directory. Choreolib expects .traj files to be placed in
117   * src/main/deploy/choreo/[trajectoryName].traj.
118   *
119   * @param <SampleType> The type of samples in the trajectory.
120   * @param trajectoryName The path name in Choreo, which matches the file name in the deploy
121   *     directory, file extension is optional.
122   * @return The loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
123   */
124  @SuppressWarnings("unchecked")
125  public static <SampleType extends TrajectorySample<SampleType>>
126      Optional<Trajectory<SampleType>> loadTrajectory(String trajectoryName) {
127    requireNonNullParam(trajectoryName, "trajectoryName", "Choreo.loadTrajectory");
128
129    if (trajectoryName.endsWith(TRAJECTORY_FILE_EXTENSION)) {
130      trajectoryName =
131          trajectoryName.substring(0, trajectoryName.length() - TRAJECTORY_FILE_EXTENSION.length());
132    }
133    File trajectoryFile = new File(CHOREO_DIR, trajectoryName + TRAJECTORY_FILE_EXTENSION);
134    try {
135      var reader = new BufferedReader(new FileReader(trajectoryFile));
136      String str = reader.lines().reduce("", (a, b) -> a + b);
137      reader.close();
138      Trajectory<SampleType> trajectory =
139          (Trajectory<SampleType>) loadTrajectoryString(str, getProjectFile());
140      return Optional.of(trajectory);
141    } catch (FileNotFoundException ex) {
142      DriverStation.reportError("Could not find trajectory file: " + trajectoryFile, false);
143    } catch (JsonSyntaxException ex) {
144      DriverStation.reportError("Could not parse trajectory file: " + trajectoryFile, false);
145    } catch (Exception ex) {
146      DriverStation.reportError(ex.getMessage(), ex.getStackTrace());
147    }
148    return Optional.empty();
149  }
150
151  /**
152   * Load a trajectory from a string.
153   *
154   * @param trajectoryJsonString The JSON string.
155   * @param projectFile The project file.
156   * @return The loaded trajectory, or `empty std::optional` if the trajectory could not be loaded.
157   */
158  static Trajectory<? extends TrajectorySample<?>> loadTrajectoryString(
159      String trajectoryJsonString, ProjectFile projectFile) {
160    JsonObject wholeTrajectory = GSON.fromJson(trajectoryJsonString, JsonObject.class);
161    String name = wholeTrajectory.get("name").getAsString();
162    String version = wholeTrajectory.get("version").getAsString();
163    if (!SPEC_VERSION.equals(version)) {
164      throw new RuntimeException(
165          name + ".traj: Wrong version: " + version + ". Expected " + SPEC_VERSION);
166    }
167    // Filter out markers with negative timestamps or empty names
168    List<EventMarker> unfilteredEvents =
169        new ArrayList<EventMarker>(
170            Arrays.asList(GSON.fromJson(wholeTrajectory.get("events"), EventMarker[].class)));
171    unfilteredEvents.removeIf(marker -> marker.timestamp < 0 || marker.event.length() == 0);
172    EventMarker[] events = new EventMarker[unfilteredEvents.size()];
173    unfilteredEvents.toArray(events);
174
175    JsonObject trajectoryObj = wholeTrajectory.getAsJsonObject("trajectory");
176    Integer[] splits = GSON.fromJson(trajectoryObj.get("splits"), Integer[].class);
177    if (splits.length == 0 || splits[0] != 0) {
178      Integer[] newArray = new Integer[splits.length + 1];
179      newArray[0] = 0;
180      System.arraycopy(splits, 0, newArray, 1, splits.length);
181      splits = newArray;
182    }
183    if (projectFile.type.equals("Swerve")) {
184      SwerveSample[] samples = GSON.fromJson(trajectoryObj.get("samples"), SwerveSample[].class);
185      return new Trajectory<SwerveSample>(name, List.of(samples), List.of(splits), List.of(events));
186    } else if (projectFile.type.equals("Differential")) {
187      DifferentialSample[] sampleArray =
188          GSON.fromJson(trajectoryObj.get("samples"), DifferentialSample[].class);
189      return new Trajectory<DifferentialSample>(
190          name, List.of(sampleArray), List.of(splits), List.of(events));
191    } else {
192      throw new RuntimeException("Unknown project type: " + projectFile.type);
193    }
194  }
195
196  /**
197   * A utility for caching loaded trajectories. This allows for loading trajectories only once, and
198   * then reusing them.
199   */
200  public static class TrajectoryCache {
201    private final Map<String, Trajectory<?>> cache;
202
203    /** Creates a new TrajectoryCache with a normal {@link HashMap} as the cache. */
204    public TrajectoryCache() {
205      cache = new HashMap<>();
206    }
207
208    /**
209     * Creates a new TrajectoryCache with a custom cache.
210     *
211     * <p>this could be useful if you want to use a concurrent map or a map with a maximum size.
212     *
213     * @param cache The cache to use.
214     */
215    public TrajectoryCache(Map<String, Trajectory<?>> cache) {
216      requireNonNullParam(cache, "cache", "TrajectoryCache.<init>");
217      this.cache = cache;
218    }
219
220    /**
221     * Load a trajectory from the deploy directory. Choreolib expects .traj files to be placed in
222     * src/main/deploy/choreo/[trajectoryName].traj.
223     *
224     * <p>This method will cache the loaded trajectory and reused it if it is requested again.
225     *
226     * @param trajectoryName the path name in Choreo, which matches the file name in the deploy
227     *     directory, file extension is optional.
228     * @return the loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
229     * @see Choreo#loadTrajectory(String)
230     */
231    public Optional<? extends Trajectory<?>> loadTrajectory(String trajectoryName) {
232      requireNonNullParam(trajectoryName, "trajectoryName", "TrajectoryCache.loadTrajectory");
233      if (cache.containsKey(trajectoryName)) {
234        return Optional.of(cache.get(trajectoryName));
235      } else {
236        return Choreo.loadTrajectory(trajectoryName)
237            .map(
238                trajectory -> {
239                  cache.put(trajectoryName, trajectory);
240                  return trajectory;
241                });
242      }
243    }
244
245    /**
246     * Load a section of a split trajectory from the deploy directory. Choreolib expects .traj files
247     * to be placed in src/main/deploy/choreo/[trajectoryName].traj.
248     *
249     * <p>This method will cache the loaded trajectory and reused it if it is requested again. The
250     * trajectory that is split off of will also be cached.
251     *
252     * @param trajectoryName the path name in Choreo, which matches the file name in the deploy
253     *     directory, file extension is optional.
254     * @param splitIndex the index of the split trajectory to load
255     * @return the loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
256     * @see Choreo#loadTrajectory(String)
257     */
258    public Optional<? extends Trajectory<?>> loadTrajectory(String trajectoryName, int splitIndex) {
259      requireNonNullParam(trajectoryName, "trajectoryName", "TrajectoryCache.loadTrajectory");
260      // make the key something that could never possibly be a valid trajectory name
261      String key = trajectoryName + ".:." + splitIndex;
262      if (cache.containsKey(key)) {
263        return Optional.of(cache.get(key));
264      } else if (cache.containsKey(trajectoryName)) {
265        return cache
266            .get(trajectoryName)
267            .getSplit(splitIndex)
268            .map(
269                trajectory -> {
270                  cache.put(key, trajectory);
271                  return trajectory;
272                });
273      } else {
274        return Choreo.loadTrajectory(trajectoryName)
275            .flatMap(
276                trajectory -> {
277                  cache.put(trajectoryName, trajectory);
278                  return trajectory
279                      .getSplit(splitIndex)
280                      .map(
281                          split -> {
282                            cache.put(key, split);
283                            return split;
284                          });
285                });
286      }
287    }
288
289    /** Clear the cache. */
290    public void clear() {
291      cache.clear();
292    }
293  }
294
295  /**
296   * Create a factory that can be used to create {@link AutoRoutine} and {@link AutoTrajectory}.
297   *
298   * @param <SampleType> The type of samples in the trajectory.
299   * @param poseSupplier A function that returns the current field-relative {@link Pose2d} of the
300   *     robot.
301   * @param controller A {@link BiConsumer} to follow the current {@link Trajectory}&lt;{@link
302   *     SampleType}&gt;.
303   * @param mirrorTrajectory If this returns true, the path will be mirrored to the opposite side,
304   *     while keeping the same coordinate system origin. This will be called every loop during the
305   *     command.
306   * @param driveSubsystem The drive {@link Subsystem} to require for {@link AutoTrajectory} {@link
307   *     Command}s.
308   * @param bindings Universal trajectory event bindings.
309   * @return An {@link AutoFactory} that can be used to create {@link AutoRoutine} and {@link
310   *     AutoTrajectory}.
311   * @see AutoChooser using this factory with AutoChooser to generate auto routines.
312   */
313  public static <SampleType extends TrajectorySample<SampleType>> AutoFactory createAutoFactory(
314      Supplier<Pose2d> poseSupplier,
315      BiConsumer<Pose2d, SampleType> controller,
316      BooleanSupplier mirrorTrajectory,
317      Subsystem driveSubsystem,
318      AutoBindings bindings) {
319    return new AutoFactory(
320        requireNonNullParam(poseSupplier, "poseSupplier", "Choreo.createAutoFactory"),
321        requireNonNullParam(controller, "controller", "Choreo.createAutoFactory"),
322        requireNonNullParam(mirrorTrajectory, "mirrorTrajectory", "Choreo.createAutoFactory"),
323        requireNonNullParam(driveSubsystem, "driveSubsystem", "Choreo.createAutoFactory"),
324        requireNonNullParam(bindings, "bindings", "Choreo.createAutoFactory"),
325        Optional.empty());
326  }
327
328  /**
329   * Create a factory that can be used to create {@link AutoRoutine} and {@link AutoTrajectory}.
330   *
331   * @param <SampleType> The type of samples in the trajectory.
332   * @param poseSupplier A function that returns the current field-relative {@link Pose2d} of the
333   *     robot.
334   * @param controller A {@link BiConsumer} to follow the current {@link Trajectory}&lt;{@link
335   *     SampleType}&gt;.
336   * @param mirrorTrajectory If this returns true, the path will be mirrored to the opposite side,
337   *     while keeping the same coordinate system origin. This will be called every loop during the
338   *     command.
339   * @param driveSubsystem The drive {@link Subsystem} to require for {@link AutoTrajectory} {@link
340   *     Command}s.
341   * @param bindings Universal trajectory event bindings.
342   * @param trajectoryLogger A {@link TrajectoryLogger} to log {@link Trajectory} as they start and
343   *     finish.
344   * @return An {@link AutoFactory} that can be used to create {@link AutoRoutine} and {@link
345   *     AutoTrajectory}.
346   * @see AutoChooser using this factory with AutoChooser to generate auto routines.
347   */
348  public static <SampleType extends TrajectorySample<SampleType>> AutoFactory createAutoFactory(
349      Supplier<Pose2d> poseSupplier,
350      BiConsumer<Pose2d, SampleType> controller,
351      BooleanSupplier mirrorTrajectory,
352      Subsystem driveSubsystem,
353      AutoBindings bindings,
354      TrajectoryLogger<SampleType> trajectoryLogger) {
355    return new AutoFactory(
356        requireNonNullParam(poseSupplier, "poseSupplier", "Choreo.createAutoFactory"),
357        requireNonNullParam(controller, "controller", "Choreo.createAutoFactory"),
358        requireNonNullParam(mirrorTrajectory, "mirrorTrajectory", "Choreo.createAutoFactory"),
359        requireNonNullParam(driveSubsystem, "driveSubsystem", "Choreo.createAutoFactory"),
360        requireNonNullParam(bindings, "bindings", "Choreo.createAutoFactory"),
361        Optional.of(trajectoryLogger));
362  }
363}