//
// Copyright 2024 DXOS.org
//

import { Trigger } from '@dxos/async';
import { log } from '@dxos/log';

import type { GenerationStream } from './service';
import { type GenerationStreamEvent } from './types';
import { iterSSEMessages } from './util';

/**
 * Creates a stream from an SSE response.
 */
/**
 * Creates a generation stream from an SSE response.
 *
 * Event handling:
 * - `completion`, `message_*` and `content_block_*` events: the JSON payload is parsed and yielded.
 * - `ping` keep-alives and unrecognized events: skipped.
 * - `error` events: terminate the stream with an `Error`.
 *
 * @param response HTTP response whose body carries the SSE messages.
 * @param controller Controller handed to the SSE reader. It is aborted when iteration stops before the
 *   response has been fully consumed (consumer `break`, malformed payload, or `error` event).
 */
export const createGenerationStream = (response: Response, controller = new AbortController()): GenerationStream => {
  const iterator = async function* (): AsyncGenerator<GenerationStreamEvent> {
    // Set only once every SSE message has been read; anything else counts as an early exit.
    let done = false;
    try {
      for await (const sse of iterSSEMessages(response, controller)) {
        switch (sse.event) {
          case 'completion':
          case 'message_start':
          case 'message_delta':
          case 'message_stop':
          case 'content_block_start':
          case 'content_block_delta':
          case 'content_block_stop': {
            // Parse outside the `yield` so that only genuine JSON failures are reported as parse errors.
            let event: GenerationStreamEvent;
            try {
              event = JSON.parse(sse.data);
            } catch (err) {
              log.error('could not parse message into JSON:', { data: sse.data, raw: sse.raw });
              throw err;
            }
            yield event;
            break;
          }

          case 'error':
            throw new Error(`Message generation error: ${sse.data}`);

          default:
            // 'ping' keep-alives and unknown events carry nothing for the consumer.
            break;
        }
      }
      done = true;
    } finally {
      // Release the underlying request if iteration ended early.
      if (!done) {
        controller.abort();
      }
    }
  };

  return new GenerationStreamImpl(controller, iterator);
};

/**
 * Server-Sent Events (SSE) stream from the AI service.
 * https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events
 */
export class GenerationStreamImpl implements GenerationStream {
  /**
   * Settled once the shared iterator finishes.
   * It is resolved on normal completion or when the consumer exits early.
   * It is rejected with the error thrown by the underlying event source.
   */
  private readonly _done = new Trigger();

  /**
   * Shared iterator over the stream; created lazily on first iteration and reused afterwards.
   */
  private _iterator?: AsyncIterator<GenerationStreamEvent> = undefined;

  constructor(
    private readonly _controller: AbortController,
    private readonly _getIterator: () => AsyncIterableIterator<GenerationStreamEvent>,
  ) {}

  /**
   * Returns the iterator over the stream.
   * Every call returns the same shared iterator, so events are delivered only once.
   */
  [Symbol.asyncIterator](): AsyncIterator<GenerationStreamEvent> {
    return this._createIterator();
  }

  /**
   * Aborts the stream by signalling the controller passed to the constructor.
   */
  abort(): void {
    this._controller.abort();
  }

  /**
   * Waits until iteration of the stream has finished.
   * Resolves when the events are exhausted or the consumer stops early.
   * Rejects with the error raised by the underlying event source.
   * NOTE: this does not drive the stream itself; it only settles once something iterates it.
   */
  async complete(): Promise<void> {
    await this._done.wait();
  }

  /**
   * Returns the shared iterator, creating it on first use.
   */
  private _createIterator(): AsyncIterator<GenerationStreamEvent> {
    return (this._iterator ??= this._iterate());
  }

  /**
   * Relays events from the underlying source and settles `_done` exactly once however iteration ends.
   */
  private async *_iterate(): AsyncGenerator<GenerationStreamEvent> {
    let settled = false;
    try {
      for await (const event of this._getIterator()) {
        yield event;
      }
      settled = true;
      this._done.wake();
    } catch (err: unknown) {
      settled = true;
      this._done.throw(err instanceof Error ? err : new Error(String(err)));
      throw err;
    } finally {
      // The consumer left early (`break`/`return`); without this `complete()` would never settle.
      if (!settled) {
        this._done.wake();
      }
    }
  }
}
