/*
 * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
 * Licensed under the MIT License.
 */
package ai.onnxruntime;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * An abstract base class for execution provider options classes.
 *
 * <p>Each instance wraps a native provider options pointer, which is released by {@link
 * #close()}.
 */
// Note this lives in ai.onnxruntime to allow subclasses to access the OnnxRuntime.ortApiHandle
// package private field.
public abstract class OrtProviderOptions implements AutoCloseable {
  private static final Logger logger = Logger.getLogger(OrtProviderOptions.class.getName());

  static {
    try {
      OnnxRuntime.init();
    } catch (IOException e) {
      throw new RuntimeException("Failed to load onnx-runtime library", e);
    }
  }

  /** The native pointer. */
  protected final long nativeHandle;

  /** Is the native object closed? Read and written under this object's monitor by this class. */
  protected boolean closed;

  /**
   * Constructs an OrtProviderOptions wrapped around a native pointer.
   *
   * @param nativeHandle The native pointer.
   */
  protected OrtProviderOptions(long nativeHandle) {
    this.nativeHandle = nativeHandle;
    this.closed = false;
  }

  /**
   * Allow access to the api handle pointer for subclasses.
   *
   * @return The api handle.
   */
  protected static long getApiHandle() {
    return OnnxRuntime.ortApiHandle;
  }

  /**
   * Gets the provider enum for this options instance.
   *
   * @return The provider enum.
   */
  public abstract OrtProvider getProvider();

  /**
   * Applies the Java side configuration to the native side object.
   *
   * @throws OrtException If the native call failed.
   */
  protected abstract void applyToNative() throws OrtException;

  /**
   * Is the native object closed?
   *
   * @return True if the native object has been released.
   */
  public synchronized boolean isClosed() {
    return closed;
  }

  /**
   * Releases the native provider options object.
   *
   * <p>If this instance is already closed, the method only logs a warning. It is synchronized, the
   * same as {@link #isClosed()}, so concurrent callers cannot both observe {@code closed == false}
   * and release the native pointer twice.
   */
  @Override
  public synchronized void close() {
    if (!closed) {
      close(OnnxRuntime.ortApiHandle, nativeHandle);
      closed = true;
    } else {
      logger.warning("Closing an already closed OrtProviderOptions.");
    }
  }

  /** Checks if the OrtProviderOptions is closed, if so throws {@link IllegalStateException}. */
  protected void checkClosed() {
    if (closed) {
      throw new IllegalStateException("Trying to use a closed OrtProviderOptions");
    }
  }

  /**
   * Native close method.
   *
   * @param apiHandle The api pointer.
   * @param nativeHandle The native options pointer.
   */
  protected abstract void close(long apiHandle, long nativeHandle);

  /**
   * Loads the provider's shared library (if necessary) and calls the create provider function.
   *
   * <p>For CUDA, DNNL, OpenVINO, ROCm and TensorRT, the matching {@code OnnxRuntime.extract*}
   * method is called first. For any other provider, {@code createFunction} is invoked directly.
   *
   * @param provider The OrtProvider for this options.
   * @param createFunction The create function.
   * @return The pointer to the native provider options object.
   * @throws OrtException If either the library load or provider options create call failed.
   */
  protected static long loadLibraryAndCreate(
      OrtProvider provider, OrtProviderSupplier createFunction) throws OrtException {
    // Shared providers need their libraries loaded before options can be defined.
    switch (provider) {
      case CUDA:
        requireSharedProvider(OnnxRuntime.extractCUDA(), "CUDA");
        break;
      case DNNL:
        requireSharedProvider(OnnxRuntime.extractDNNL(), "DNNL");
        break;
      case OPEN_VINO:
        requireSharedProvider(OnnxRuntime.extractOpenVINO(), "OpenVINO");
        break;
      case ROCM:
        requireSharedProvider(OnnxRuntime.extractROCM(), "ROCm");
        break;
      case TENSOR_RT:
        requireSharedProvider(OnnxRuntime.extractTensorRT(), "TensorRT");
        break;
      default:
        // No separate shared library extraction step for the remaining providers.
        break;
    }

    return createFunction.create();
  }

  /**
   * Throws an {@link OrtException} if a shared provider library could not be found.
   *
   * @param found The result of the corresponding {@code OnnxRuntime.extract*} call.
   * @param providerName The human readable provider name used in the error message.
   * @throws OrtException With code {@code ORT_EP_FAIL} if {@code found} is false.
   */
  private static void requireSharedProvider(boolean found, String providerName)
      throws OrtException {
    if (!found) {
      throw new OrtException(
          OrtException.OrtErrorCode.ORT_EP_FAIL,
          "Failed to find " + providerName + " shared provider");
    }
  }

  /** Functional interface mirroring a Java supplier, but can throw OrtException. */
  @FunctionalInterface
  public interface OrtProviderSupplier {
    /**
     * Calls the function to get the native pointer.
     *
     * @return The native pointer.
     * @throws OrtException If the create call failed.
     */
    long create() throws OrtException;
  }
}
