001// Copyright (c) Choreo contributors
002
003package choreo.auto;
004
005import choreo.Choreo;
006import choreo.Choreo.TrajectoryLogger;
007import choreo.auto.AutoFactory.AutoBindings;
008import choreo.trajectory.DifferentialSample;
009import choreo.trajectory.SwerveSample;
010import choreo.trajectory.Trajectory;
011import choreo.trajectory.TrajectorySample;
012import choreo.util.AllianceFlipUtil;
013import edu.wpi.first.math.geometry.Pose2d;
014import edu.wpi.first.math.geometry.Translation2d;
015import edu.wpi.first.math.util.Units;
016import edu.wpi.first.wpilibj.DriverStation;
017import edu.wpi.first.wpilibj.Timer;
018import edu.wpi.first.wpilibj2.command.Command;
019import edu.wpi.first.wpilibj2.command.FunctionalCommand;
020import edu.wpi.first.wpilibj2.command.Subsystem;
021import edu.wpi.first.wpilibj2.command.button.Trigger;
022import java.util.Optional;
023import java.util.function.BiConsumer;
024import java.util.function.BooleanSupplier;
025import java.util.function.Supplier;
026
027/**
028 * A class that represents a trajectory that can be used in an autonomous routine and have triggers
029 * based off of it.
030 */
031public class AutoTrajectory {
032  // For any devs looking through this class wondering
033  // about all the type casting and `?` for generics it's intentional.
034  // My goal was to make the sample type minimally leak into user code
035  // so you don't have to retype the sample type everywhere in your auto
036  // code. This also makes the places with generics exposed to users few
037  // and far between. This helps with more novice users
038
039  private static final double DEFAULT_TOLERANCE_METERS = Units.inchesToMeters(3);
040
041  private final String name;
042  private final Trajectory<? extends TrajectorySample<?>> trajectory;
043  private final TrajectoryLogger<? extends TrajectorySample<?>> trajectoryLogger;
044  private final Supplier<Pose2d> poseSupplier;
045  private final BiConsumer<Pose2d, ? extends TrajectorySample<?>> controller;
046  private final BooleanSupplier mirrorTrajectory;
047  private final Timer timer = new Timer();
048  private final Subsystem driveSubsystem;
049  private final AutoRoutine routine;
050
051  /**
052   * A way to create slightly less triggers for alot of actions. Not static as to not leak triggers
053   * made here into another static EventLoop.
054   */
055  private final Trigger offTrigger;
056
057  /** If this trajectory us currently running */
058  private boolean isActive = false;
059
060  /** If the trajectory ran to completion */
061  private boolean isCompleted = false;
062
063  /**
064   * Constructs an AutoTrajectory.
065   *
066   * @param name The trajectory name.
067   * @param trajectory The trajectory samples.
068   * @param poseSupplier The pose supplier.
069   * @param controller The controller function.
070   * @param mirrorTrajectory Getter that determines whether to mirror trajectory.
071   * @param trajectoryLogger Optional trajectory logger.
072   * @param driveSubsystem Drive subsystem.
073   * @param routine Event loop.
074   * @param bindings {@link Choreo#createAutoFactory}
075   */
076  <SampleType extends TrajectorySample<SampleType>> AutoTrajectory(
077      String name,
078      Trajectory<SampleType> trajectory,
079      Supplier<Pose2d> poseSupplier,
080      BiConsumer<Pose2d, SampleType> controller,
081      BooleanSupplier mirrorTrajectory,
082      Optional<TrajectoryLogger<SampleType>> trajectoryLogger,
083      Subsystem driveSubsystem,
084      AutoRoutine routine,
085      AutoBindings bindings) {
086    this.name = name;
087    this.trajectory = trajectory;
088    this.poseSupplier = poseSupplier;
089    this.controller = controller;
090    this.mirrorTrajectory = mirrorTrajectory;
091    this.driveSubsystem = driveSubsystem;
092    this.routine = routine;
093    this.offTrigger = new Trigger(routine.loop(), () -> false);
094    this.trajectoryLogger =
095        trajectoryLogger.isPresent()
096            ? trajectoryLogger.get()
097            : new TrajectoryLogger<SampleType>() {
098              public void accept(Trajectory<SampleType> t, Boolean u) {}
099            };
100
101    bindings.getBindings().forEach((key, value) -> active().and(atTime(key)).onTrue(value));
102  }
103
104  @SuppressWarnings("unchecked")
105  private void logTrajectory(boolean starting) {
106    TrajectorySample<?> sample = trajectory.getInitialSample(false);
107    if (sample == null) {
108      return;
109    } else if (sample instanceof SwerveSample) {
110      TrajectoryLogger<SwerveSample> swerveLogger =
111          (TrajectoryLogger<SwerveSample>) trajectoryLogger;
112      Trajectory<SwerveSample> swerveTrajectory = (Trajectory<SwerveSample>) trajectory;
113      swerveLogger.accept(swerveTrajectory, starting);
114    } else if (sample instanceof DifferentialSample) {
115      TrajectoryLogger<DifferentialSample> differentialLogger =
116          (TrajectoryLogger<DifferentialSample>) trajectoryLogger;
117      Trajectory<DifferentialSample> differentialTrajectory =
118          (Trajectory<DifferentialSample>) trajectory;
119      differentialLogger.accept(differentialTrajectory, starting);
120    }
121    ;
122  }
123
124  private void cmdInitialize() {
125    timer.restart();
126    isActive = true;
127    isCompleted = false;
128    logTrajectory(true);
129  }
130
131  @SuppressWarnings("unchecked")
132  private void cmdExecute() {
133    var sample = trajectory.sampleAt(timer.get(), mirrorTrajectory.getAsBoolean());
134    if (sample instanceof SwerveSample swerveSample) {
135      var swerveController = (BiConsumer<Pose2d, SwerveSample>) this.controller;
136      swerveController.accept(poseSupplier.get(), swerveSample);
137    } else if (sample instanceof DifferentialSample differentialSample) {
138      var differentialController = (BiConsumer<Pose2d, DifferentialSample>) this.controller;
139      differentialController.accept(poseSupplier.get(), differentialSample);
140    }
141  }
142
143  private void cmdEnd(boolean interrupted) {
144    timer.stop();
145    isActive = false;
146    isCompleted = !interrupted;
147    cmdExecute(); // will force the controller to be fed the final sample
148    logTrajectory(false);
149  }
150
151  private boolean cmdIsFinished() {
152    return timer.get() > trajectory.getTotalTime() || !routine.isActive;
153  }
154
155  /**
156   * Creates a command that allocates the drive subsystem and follows the trajectory using the
157   * factories control function
158   *
159   * @return The command that will follow the trajectory
160   */
161  public Command cmd() {
162    // if the trajectory is empty, return a command that will print an error
163    if (trajectory.samples().isEmpty()) {
164      return driveSubsystem
165          .runOnce(
166              () -> {
167                DriverStation.reportError("Trajectory " + name + " has no samples", false);
168              })
169          .withName("Trajectory_" + name);
170    }
171    return new FunctionalCommand(
172            this::cmdInitialize,
173            this::cmdExecute,
174            this::cmdEnd,
175            this::cmdIsFinished,
176            driveSubsystem)
177        .withName("Trajectory_" + name);
178  }
179
180  /**
181   * Will get the underlying {@link Trajectory} object.
182   *
183   * <p><b>WARNING:</b> This method is not type safe and should be used with caution. The sample
184   * type of the trajectory should be known before calling this method.
185   *
186   * @param <SampleType> The type of the trajectory samples.
187   * @return The underlying {@link Trajectory} object.
188   */
189  @SuppressWarnings("unchecked")
190  public <SampleType extends TrajectorySample<SampleType>>
191      Trajectory<SampleType> getRawTrajectory() {
192    return (Trajectory<SampleType>) trajectory;
193  }
194
195  /**
196   * Will get the starting pose of the trajectory.
197   *
198   * <p>This position is mirrored based on the {@code mirrorTrajectory} boolean supplier in the
199   * factory used to make this trajectory
200   *
201   * @return The starting pose
202   */
203  public Optional<Pose2d> getInitialPose() {
204    if (trajectory.samples().isEmpty()) {
205      return Optional.empty();
206    }
207    return Optional.ofNullable(trajectory.getInitialPose(mirrorTrajectory.getAsBoolean()));
208  }
209
210  /**
211   * Will get the ending pose of the trajectory.
212   *
213   * <p>This position is mirrored based on the {@code mirrorTrajectory} boolean supplier in the
214   * factory used to make this trajectory
215   *
216   * @return The starting pose
217   */
218  public Optional<Pose2d> getFinalPose() {
219    if (trajectory.samples().isEmpty()) {
220      return Optional.empty();
221    }
222    return Optional.ofNullable(trajectory.getFinalPose(mirrorTrajectory.getAsBoolean()));
223  }
224
225  /**
226   * Returns a trigger that is true while the trajectory is scheduled.
227   *
228   * @return A trigger that is true while the trajectory is scheduled.
229   */
230  public Trigger active() {
231    return new Trigger(routine.loop(), () -> this.isActive && routine.isActive);
232  }
233
234  /**
235   * Returns a trigger that is true while the command is not scheduled.
236   *
237   * <p>The same as calling <code>active().negate()</code>.
238   *
239   * @return A trigger that is true while the command is not scheduled.
240   */
241  public Trigger inactive() {
242    return active().negate();
243  }
244
245  /**
246   * Returns a trigger that rises to true when the trajectory ends and falls when another trajectory
247   * is run.
248   *
249   * <p>This is different from inactive() in a few ways.
250   *
251   * <ul>
252   *   <li>This will never be true if the trajectory is interupted
253   *   <li>This will never be true before the trajectory is run
254   *   <li>This will fall the next cycle after the trajectory ends
255   * </ul>
256   *
257   * <p>Why does the trigger need to fall?
258   *
259   * <pre><code>
260   * //Lets say we had this code segment
261   * Trigger hasGamepiece = ...;
262   * Trigger noGamepiece = hasGamepiece.negate();
263   *
264   * AutoTrajectory rushMidTraj = ...;
265   * AutoTrajectory goShootGamepiece = ...;
266   * AutoTrajectory pickupAnotherGamepiece = ...;
267   *
268   * routine.enabled().onTrue(rushMidTraj.cmd());
269   *
270   * rushMidTraj.done(10).and(noGamepiece).onTrue(pickupAnotherGamepiece.cmd());
271   * rushMidTraj.done(10).and(hasGamepiece).onTrue(goShootGamepiece.cmd());
272   *
273   * // If done never falls when a new trajectory is scheduled
274   * // then these triggers leak into the next trajectory, causing the next note pickup
275   * // to trigger goShootGamepiece.cmd() even if we no longer care about these checks
276   * </code></pre>
277   *
278   * @param cyclesToDelay The number of cycles to delay the trigger from rising to true.
279   * @return A trigger that is true when the trajectoy is finished.
280   */
281  public Trigger done(int cyclesToDelay) {
282    BooleanSupplier checker =
283        new BooleanSupplier() {
284          /** The last used value for trajectory completeness */
285          boolean lastCompleteness = false;
286
287          /** The cycle to be true for */
288          int cycleTarget = -1;
289
290          @Override
291          public boolean getAsBoolean() {
292            if (!isCompleted) {
293              // update last seen value
294              lastCompleteness = false;
295              return false;
296            }
297            if (!lastCompleteness && isCompleted) {
298              // if just flipped to completed update last seen value
299              // and store the cycleTarget based of the current cycle
300              lastCompleteness = true;
301              cycleTarget = routine.pollCount() + cyclesToDelay;
302            }
303            // finally if check the cycle matches the target
304            return routine.pollCount() == cycleTarget;
305          }
306        };
307    return inactive().and(new Trigger(routine.loop(), checker));
308  }
309
310  /**
311   * Returns a trigger that rises to true when the trajectory ends and falls when another trajectory
312   * is run.
313   *
314   * <p>This is different from inactive() in a few ways.
315   *
316   * <ul>
317   *   <li>This will never be true if the trajectory is interupted
318   *   <li>This will never be true before the trajectory is run
319   *   <li>This will fall the next cycle after the trajectory ends
320   * </ul>
321   *
322   * <p>Why does the trigger need to fall?
323   *
324   * <pre><code>
325   * //Lets say we had this code segment
326   * Trigger hasGamepiece = ...;
327   * Trigger noGamepiece = hasGamepiece.negate();
328   *
329   * AutoTrajectory rushMidTraj = ...;
330   * AutoTrajectory goShootGamepiece = ...;
331   * AutoTrajectory pickupAnotherGamepiece = ...;
332   *
333   * routine.enabled().onTrue(rushMidTraj.cmd());
334   *
335   * rushMidTraj.done().and(noGamepiece).onTrue(pickupAnotherGamepiece.cmd());
336   * rushMidTraj.done().and(hasGamepiece).onTrue(goShootGamepiece.cmd());
337   *
338   * // If done never falls when a new trajectory is scheduled
339   * // then these triggers leak into the next trajectory, causing the next note pickup
340   * // to trigger goShootGamepiece.cmd() even if we no longer care about these checks
341   * </code></pre>
342   *
343   * @return A trigger that is true when the trajectoy is finished.
344   */
345  public Trigger done() {
346    return done(0);
347  }
348
349  /**
350   * Returns a trigger that will go true for 1 cycle when the desired time has elapsed
351   *
352   * @param timeSinceStart The time since the command started in seconds.
353   * @return A trigger that is true when timeSinceStart has elapsed.
354   */
355  public Trigger atTime(double timeSinceStart) {
356    // The timer shhould never be negative so report this as a warning
357    if (timeSinceStart < 0) {
358      DriverStation.reportWarning("Trigger time cannot be negative for " + name, true);
359      return offTrigger;
360    }
361
362    // The timer should never exceed the total trajectory time so report this as a warning
363    if (timeSinceStart > trajectory.getTotalTime()) {
364      DriverStation.reportWarning(
365          "Trigger time cannot be greater than total trajectory time for " + name, true);
366      return offTrigger;
367    }
368
369    // Make the trigger only be high for 1 cycle when the time has elapsed,
370    // this is needed for better support of multi-time triggers for multi events
371    return new Trigger(
372        routine.loop(),
373        new BooleanSupplier() {
374          double lastTimestamp = timer.get();
375
376          public boolean getAsBoolean() {
377            double nowTimestamp = timer.get();
378            try {
379              return lastTimestamp < nowTimestamp && nowTimestamp >= timeSinceStart;
380            } finally {
381              lastTimestamp = nowTimestamp;
382            }
383          }
384        });
385  }
386
387  /**
388   * Returns a trigger that is true when the event with the given name has been reached based on
389   * time.
390   *
391   * <p>A warning will be printed to the DriverStation if the event is not found and the trigger
392   * will always be false.
393   *
394   * @param eventName The name of the event.
395   * @return A trigger that is true when the event with the given name has been reached based on
396   *     time.
397   */
398  public Trigger atTime(String eventName) {
399    boolean foundEvent = false;
400    Trigger trig = offTrigger;
401
402    for (var event : trajectory.getEvents(eventName)) {
403      // This could create alot of objects, could be done a more efficient way
404      // with having it all be 1 trigger that just has a list of times and checks each one each
405      // cycle
406      // or something like that. If choreo starts proposing memory issues we can look into this.
407      trig = trig.or(atTime(event.timestamp));
408      foundEvent = true;
409    }
410
411    // The user probably expects an event to exist if they're trying to do something at that event,
412    // report the missing event.
413    if (!foundEvent) {
414      DriverStation.reportWarning("Event \"" + eventName + "\" not found for " + name, true);
415    }
416
417    return trig;
418  }
419
420  /**
421   * Returns a trigger that is true when the robot is within toleranceMeters of the given pose.
422   *
423   * <p>This position is mirrored based on the {@code mirrorTrajectory} boolean supplier in the
424   * factory used to make this trajectory.
425   *
426   * @param pose The pose to check against
427   * @param toleranceMeters The tolerance in meters.
428   * @return A trigger that is true when the robot is within toleranceMeters of the given pose.
429   */
430  public Trigger atPose(Pose2d pose, double toleranceMeters) {
431    Translation2d checkedTrans =
432        mirrorTrajectory.getAsBoolean()
433            ? AllianceFlipUtil.flip(pose.getTranslation())
434            : pose.getTranslation();
435    return new Trigger(
436        routine.loop(),
437        () -> {
438          Translation2d currentTrans = poseSupplier.get().getTranslation();
439          return currentTrans.getDistance(checkedTrans) < toleranceMeters;
440        });
441  }
442
443  /**
444   * Returns a trigger that is true when the robot is within toleranceMeters of the given events
445   * pose.
446   *
447   * <p>A warning will be printed to the DriverStation if the event is not found and the trigger
448   * will always be false.
449   *
450   * @param eventName The name of the event.
451   * @param toleranceMeters The tolerance in meters.
452   * @return A trigger that is true when the robot is within toleranceMeters of the given events
453   *     pose.
454   */
455  public Trigger atPose(String eventName, double toleranceMeters) {
456    boolean foundEvent = false;
457    Trigger trig = offTrigger;
458
459    for (var event : trajectory.getEvents(eventName)) {
460      // This could create alot of objects, could be done a more efficient way
461      // with having it all be 1 trigger that just has a list of possess and checks each one each
462      // cycle or something like that.
463      // If choreo starts proposing memory issues we can look into this.
464      Pose2d pose = trajectory.sampleAt(event.timestamp, mirrorTrajectory.getAsBoolean()).getPose();
465      trig = trig.or(atPose(pose, toleranceMeters));
466      foundEvent = true;
467    }
468
469    // The user probably expects an event to exist if they're trying to do something at that event,
470    // report the missing event.
471    if (!foundEvent) {
472      DriverStation.reportWarning("Event \"" + eventName + "\" not found for " + name, true);
473    }
474
475    return trig;
476  }
477
478  /**
479   * Returns a trigger that is true when the robot is within 3 inches of the given events pose.
480   *
481   * <p>A warning will be printed to the DriverStation if the event is not found and the trigger
482   * will always be false.
483   *
484   * @param eventName The name of the event.
485   * @return A trigger that is true when the robot is within 3 inches of the given events pose.
486   */
487  public Trigger atPose(String eventName) {
488    return atPose(eventName, DEFAULT_TOLERANCE_METERS);
489  }
490
491  /**
492   * Returns a trigger that is true when the event with the given name has been reached based on
493   * time and the robot is within toleranceMeters of the given events pose.
494   *
495   * <p>A warning will be printed to the DriverStation if the event is not found and the trigger
496   * will always be false.
497   *
498   * @param eventName The name of the event.
499   * @param toleranceMeters The tolerance in meters.
500   * @return A trigger that is true when the event with the given name has been reached based on
501   *     time and the robot is within toleranceMeters of the given events pose.
502   */
503  public Trigger atTimeAndPose(String eventName, double toleranceMeters) {
504    return atTime(eventName).and(atPose(eventName, toleranceMeters));
505  }
506
507  /**
508   * Returns a trigger that is true when the event with the given name has been reached based on
509   * time and the robot is within 3 inches of the given events pose.
510   *
511   * <p>A warning will be printed to the DriverStation if the event is not found and the trigger
512   * will always be false.
513   *
514   * @param eventName The name of the event.
515   * @return A trigger that is true when the event with the given name has been reached based on
516   *     time and the robot is within 3 inches of the given events pose.
517   */
518  public Trigger atTimeAndPose(String eventName) {
519    return atTimeAndPose(eventName, DEFAULT_TOLERANCE_METERS);
520  }
521
522  /**
523   * Returns an array of all the timestamps of the events with the given name.
524   *
525   * @param eventName The name of the event.
526   * @return An array of all the timestamps of the events with the given name.
527   */
528  public double[] collectEventTimes(String eventName) {
529    return trajectory.getEvents(eventName).stream().mapToDouble(e -> e.timestamp).toArray();
530  }
531
532  /**
533   * Returns an array of all the poses of the events with the given name.
534   *
535   * @param eventName The name of the event.
536   * @return An array of all the poses of the events with the given name.
537   */
538  public Pose2d[] collectEventPoses(String eventName) {
539    var times = collectEventTimes(eventName);
540    var poses = new Pose2d[times.length];
541    for (int i = 0; i < times.length; i++) {
542      poses[i] = trajectory.sampleAt(times[i], mirrorTrajectory.getAsBoolean()).getPose();
543    }
544    return poses;
545  }
546
547  @Override
548  public boolean equals(Object obj) {
549    return obj instanceof AutoTrajectory traj && name.equals(traj.name);
550  }
551}